
SubTab:

Author: Talip Ucar ([email protected])

The official implementation of the paper,

SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning


🔶 Note: The extended version of SubTab, with code and pre-processed data for the Adult Income and BlogFeedback datasets, can be found at: https://github.com/talipucar/SubTab_extended

Table of Contents:

  1. Model
  2. Environment
  3. Data
  4. Configuration
  5. Training and Evaluation
  6. Adding New Datasets
  7. Results
  8. Experiment tracking
  9. Citing the paper
  10. Citing this repo
NeurIPS 2021 slides · NeurIPS 2021 poster

Model

(Figure: animation of the SubTab framework; a slower version of the animation is linked in the repo.)

Environment

We used Python 3.7 for our experiments. The environment can be set up in three steps:

pip install pipenv             # To install pipenv if you don't have it already
pipenv install --skip-lock     # To install required packages. 
pipenv shell                   # To activate virtual env

If the second step runs into issues, you can install the packages listed in the Pipfile individually with pip, i.e. pip install package_name.

Data

The MNIST dataset is already provided to demo the framework. For your own dataset, follow the instructions in Adding New Datasets.

Configuration

There are two types of configuration files:

  1. runtime.yaml: a high-level configuration file used by all datasets to:

    • define the random seed
    • turn MLFlow on/off (default: False)
    • turn the Python profiler on/off (default: False)
    • set the data directory
    • set the results directory

  2. A dataset-specific configuration file (e.g. mnist.yaml), used to configure the architecture of the model, the loss functions, and so on.

    • For example, the configuration file for the MNIST dataset is mnist.yaml. Please note that the name of the configuration file should be the same as the name of the dataset, with all letters in lowercase.
    • Configuration files for other datasets follow the same convention, e.g. tcga.yaml and income.yaml for the tcga and income datasets respectively.
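As a minimal sketch of this naming convention (illustrative only: the function name and the use of PyYAML below are assumptions, not the repo's actual loading code):

import yaml  # PyYAML, assumed to be available in the environment

def load_dataset_config(dataset_name, config_dir="./config"):
    # The dataset-specific config file shares the dataset's name in lowercase,
    # e.g. "MNIST" -> ./config/mnist.yaml
    path = f"{config_dir}/{dataset_name.lower()}.yaml"
    with open(path) as f:
        return yaml.safe_load(f)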

Training and Evaluation

You can train and evaluate the model by using:

python train.py # For training. 
python eval.py  # For evaluation
  • train.py will also run evaluation at the end of the training.
  • You can also run evaluation separately by using eval.py.
  • For a list of arguments, please see ./utils/arguments.py.
    • Use the -h argument to get help when running the scripts.
    • Use -d dataset_name to run the scripts on a new dataset.
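For example, assuming a tcga dataset and its config file are in place (see Adding New Datasets below), both scripts can be pointed at it with -d:

python train.py -d tcga    # Train (and evaluate at the end) on the tcga dataset
python eval.py -d tcga     # Run evaluation only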

Adding New Datasets

For each new dataset, follow these steps:

  1. Provide a _load_dataset_name() function, similar to the MNIST load function (a sketch is given after this list).

    • For example, you can add _load_tcga() for the tcga dataset, or _load_income() for the income dataset.
    • The function should return (x_train, y_train, x_test, y_test).
  2. Add a corresponding elif condition within the _load_data() method of the TabularDataset() class in utils/load_data.py.

  3. Create a new config file with the same name as the dataset.

    • For example, tcga.yaml for the tcga dataset, or income.yaml for the income dataset.

    • You can also duplicate one of the existing configuration files (e.g. mnist.yaml) and rename it.

    • Make sure that the new config file is under the config/ directory.

  4. Provide a data folder with the pre-processed training and test sets, and place it under the ./data/ directory. You can also do the train-test split and pre-processing within your custom _load_dataset_name() function.

  5. (Optional) If you want to place the new dataset under a different directory than the local "./data/":

    • Place the dataset folder anywhere, and define its root directory in /config/runtime.yaml.

    • For example, if the path to the tcga dataset is /home/.../data/tcga/, you only need to include /home/.../data/ in runtime.yaml. The code will fill in the tcga folder name from the dataset name given as the command-line argument (e.g. -d tcga).
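A minimal sketch of step 1, assuming a hypothetical tcga dataset stored as NumPy arrays (the file names and loading logic here are illustrative; only the return signature is prescribed by the framework):

import os
import numpy as np

def _load_tcga(data_dir="./data/tcga"):
    # Hypothetical loader: reads pre-processed train/test splits saved as
    # .npy files. Any loading or pre-processing logic works, as long as the
    # function returns (x_train, y_train, x_test, y_test).
    x_train = np.load(os.path.join(data_dir, "x_train.npy"))
    y_train = np.load(os.path.join(data_dir, "y_train.npy"))
    x_test = np.load(os.path.join(data_dir, "x_test.npy"))
    y_test = np.load(os.path.join(data_dir, "y_test.npy"))
    return x_train, y_train, x_test, y_test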

Structure of the repo

- train.py
- eval.py

- src
    |-model.py
    
- config
    |-runtime.yaml
    |-mnist.yaml
    
- utils
    |-load_data.py
    |-arguments.py
    |-model_utils.py
    |-loss_functions.py
    ...
    
- data
    |-mnist
    ...
    
- results
    |
    ...

Results

Results at the end of training are saved under the ./results directory. The results directory structure is as follows:

- results
    |-dataset name
            |-evaluation
                |-clusters (for plotting t-SNE and PCA plots of embeddings)
                |-reconstructions (not used)
            |-training
                |-model_mode (e.g. ae for autoencoder)   
                     |-model
                     |-plots
                     |-loss

You can save the results of evaluations under the "evaluation" folder.

Experiment tracking

MLFlow is used to track experiments. It is turned off by default, but can be turned on by enabling the corresponding option in the runtime config file, ./config/runtime.yaml.

Citing the paper

@article{ucar2021subtab,
  title={SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning},
  author={Ucar, Talip and Hajiramezanali, Ehsan and Edwards, Lindsay},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}

Citing this repo

If you use the SubTab framework in your own studies and work, please cite it using the following:

@misc{talip_ucar_2021_SubTab,
  author       = {Talip Ucar},
  title        = {{SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning}},
  howpublished = {\url{https://github.com/AstraZeneca/SubTab}},
  month        = {June},
  year         = {since 2021}
}


SubTab Issues

Using only contrastive loss on MNIST error

Firstly, thank you for this great paper.

I would like to point out that when I train the model on the MNIST dataset, keeping everything the same as your initial implementation but turning the reconstruction and distance losses off (set to false in the configuration file) and keeping only the contrastive loss as true, I get the following error as soon as the model starts training.

Traceback (most recent call last):
  File "train.py", line 98, in <module>
    run_with_profiler(main, config) if config["profile"] else main(config)
  File "train.py", line 71, in main
    train(config, ds_loader, save_weights=True)
  File "train.py", line 34, in train
    model.fit(data_loader)
  File "C:\SubTab\src\model.py", line 116, in fit
    self.update_autoencoder(x_tilde_list, Xorig)
  File "C:\SubTab\src\model.py", line 233, in update_autoencoder
    tloss, closs, rloss, zloss = self.joint_loss(z, Xrecon, Xorig)
  File "C:\.conda\envs\modelling-dev\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\SubTab\utils\loss_functions.py", line 144, in forward
    recon_loss = getMSEloss(xrecon, xorig) if self.options["reconstruction"] else getBCELoss(xrecon, xorig)
  File "C:\SubTab\utils\loss_functions.py", line 38, in getBCELoss
    return F.binary_cross_entropy(prediction, label, reduction='sum') / bs
  File "C:\.conda\envs\modelling-dev\lib\site-packages\torch\nn\functional.py", line 3065, in binary_cross_entropy
    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)

RuntimeError: all elements of input should be between 0 and 1

After printing the first input, I noticed that some predictions have small negative values, which breaks the calculation of the binary cross-entropy loss.

e.g.

tensor([[ 0.0011,  0.0684,  0.0063,  ...,  0.0269, -0.1042,  0.0020],
        [-0.0473, -0.0807,  0.0676,  ..., -0.0229, -0.1546,  0.0397],
        [-0.0291,  0.0879,  0.0631,  ...,  0.0322, -0.0247,  0.0595],
        ...,
        [ 0.0105,  0.1674,  0.0220,  ..., -0.0655, -0.1474, -0.1166],
        [-0.0297,  0.0714,  0.0102,  ..., -0.0438, -0.0500,  0.0241],
        [ 0.0124,  0.1329,  0.0307,  ..., -0.0530, -0.0293,  0.0392]],
       grad_fn=<...>)

Is there something I am doing wrong? Thank you in advance for your help.

swap noise, encoding and loss related to implementation

For tabular datasets containing categorical features, do you apply swap noise and subset selection before one-hot encoding? Do you apply swap noise before one-hot encoding as well?
Do you use different losses for categorical, numerical, and binary features? What about activation functions? Do different feature types use different activation functions?

Recent papers that apply self-supervised and semi-supervised settings to tabular datasets only share code for the MNIST dataset; no one shares code for real tabular datasets. Could you please share your code for just one dataset containing categorical and numerical features?

How do you add Gaussian noise to categorical features

I see that you use onehot_encoder to process categorical features. However, after one-hot encoding a feature becomes [0, xxx, 0, 1, 0, xxx, 0]. Do you add Gaussian noise to all of its elements? I'm not clear on how you do this.

Package version conflict in requirements.txt

Thanks a lot for open-sourcing the code of this paper. It's really interesting!
I encountered a problem when installing the packages from requirements.txt.

ERROR: Cannot install -r requirements.txt (line 2) and chardet==4.0.0 because these package versions have conflicting dependencies.

The conflict is caused by:
    The user requested chardet==4.0.0
    aiohttp 3.7.3 depends on chardet<4.0 and >=2.0

My Python version is 3.7.1. Has anyone fixed this conflict before? Thank you.

Other datasets

Will you provide code for the other datasets in your paper (Income, Blog, Obesity, TCGA) in the future?

Suitable for Regression?

The pretraining process doesn't use label information, so I guess SubTab also applies to regression as well as classification? Or does anything need to be modified?

Please check code `def __init__(self, options)` at `HiddenLayers` in `utils/model_utils.py`

I think the HiddenLayers class code has a wrong line.

This is the wrong code:

class HiddenLayers(nn.Module):
    def __init__(self, options):
        (...omission...)
        for i in range(1, len(dims) - 1): # <--- It's wrong code.
            ....

I think you need to correct the code as below.

class HiddenLayers(nn.Module):
    def __init__(self, options):
        (...omission...)
        for i in range(1, len(dims)): # <--- It's correct code.
            ....

Please check this code; I would appreciate your feedback.

Thank you.
