CausalTransformer

Causal Transformer for estimating counterfactual outcomes over time.

The project is built with the following Python libraries:

  1. PyTorch Lightning - deep learning models
  2. Hydra - simplified command line arguments management
  3. MLflow - experiment tracking

Installation

First, create a virtual environment and install all the requirements:

pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt
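
After installation, it can help to confirm that the three libraries listed above resolve inside the virtual environment before starting a long run. The check below is a minimal sketch and not part of the repository; the import names `pytorch_lightning`, `hydra`, and `mlflow` are assumptions and may differ for the versions pinned in requirements.txt.

```python
from importlib.util import find_spec

# Import names are assumptions; adjust them to the versions in requirements.txt.
REQUIRED_MODULES = ("pytorch_lightning", "hydra", "mlflow")


def missing_modules(modules=REQUIRED_MODULES):
    """Return the top-level modules that cannot be found in this environment."""
    return [name for name in modules if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("missing:", ", ".join(missing) if missing else "none")
```

An empty result means all three modules are importable from the active environment.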

MLflow Setup / Connection

To start an experiment server, run:

mlflow server --port=5000

To access the MLflow web UI with all the experiments, connect via SSH:

ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>

Then open http://localhost:5000 in a local browser.
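
If the page does not load, the first thing to rule out is that nothing is listening on the forwarded port. The following sketch, which is not part of the repository, only tests whether a TCP connection to localhost:5000 succeeds; the function name `port_is_open` is hypothetical, and an open port does not by itself prove that the listener is MLflow.

```python
import socket


def port_is_open(host="localhost", port=5000, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unresolvable hosts.
        return False


if __name__ == "__main__":
    print("port 5000 reachable:", port_is_open())
```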

Experiments

The main training script is universal across models and datasets. For details on mandatory arguments, see the main configuration file config/config.yaml and the other files in the configs/ folder.

The generic script with logging and a fixed random seed is as follows (training-type is one of enc_dec, gnet, rmsn, or multi):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_<training-type>.py +dataset=<dataset> +backbone=<backbone> exp.seed=10 exp.logging=True
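
When launching runs from another script, the same command can be assembled programmatically. This is a sketch under assumptions: the helper names `build_train_command` and `run_training` are hypothetical, and only the four training types and the flags named above are taken from this README.

```python
import os
import subprocess

TRAINING_TYPES = ("enc_dec", "gnet", "rmsn", "multi")  # from the README text


def build_train_command(training_type, dataset, backbone, seed=10, logging=True):
    """Assemble the argv list for runnables/train_<training-type>.py."""
    if training_type not in TRAINING_TYPES:
        raise ValueError(f"unknown training type: {training_type!r}")
    return [
        "python3",
        f"runnables/train_{training_type}.py",
        f"+dataset={dataset}",
        f"+backbone={backbone}",
        f"exp.seed={seed}",
        f"exp.logging={logging}",
    ]


def run_training(argv, devices="0"):
    """Run argv from the repository root with both environment variables set."""
    env = dict(os.environ, PYTHONPATH=".", CUDA_VISIBLE_DEVICES=devices)
    return subprocess.run(argv, env=env, check=True)
```

For example, `build_train_command("multi", "cancer_sim", "ct")` yields the argument list for a Causal Transformer run on the tumor growth simulator; passing it to `run_training` starts the process.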

Backbones (baselines)

Choose a backbone and fill in its specific hyperparameters (they are left blank in the configs).

The best hyperparameters for each model and dataset are already saved and can be selected with +backbone/<backbone>_hparams/cancer_sim_<balancing_objective>=<coeff_value> or +backbone/<backbone>_hparams/mimic3_real=diastolic_blood_pressure.
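
The override template is easier to read with its placeholders filled in. The helper below is a hypothetical illustration, not repository code; the two example values are the ones that appear elsewhere in this README.

```python
def hparams_override(backbone, config_group, choice):
    """Format a Hydra override that selects a saved hyperparameter file.

    config_group is the part after '<backbone>_hparams/', for example
    'cancer_sim_domain_conf' or 'mimic3_real'.
    """
    return f"+backbone/{backbone}_hparams/{config_group}={choice}"


print(hparams_override("ct", "cancer_sim_domain_conf", "0"))
print(hparams_override("ct", "mimic3_real", "diastolic_blood_pressure"))
```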

For CT, EDCT, and CRN, several adversarial balancing objectives are available:

  • counterfactual domain confusion loss (this paper): exp.balancing=domain_confusion
  • gradient reversal (originally in CRN, but can be used for all the methods): exp.balancing=grad_reverse

To train a decoder (for CRN and RMSNs), use the flag model.train_decoder=True.

To tune hyperparameters manually, set the flag model.<sub_model>.tune_hparams=True and see model.<sub_model>.hparams_grid for the search grid. Use model.<sub_model>.tune_range to specify the number of trials for random search.
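
For readers unfamiliar with random search: each trial draws one value per hyperparameter from the grid, and the number of trials caps the search cost. The sketch below shows the general technique only and is not the repository's tuner; the grid keys and values are hypothetical stand-ins for whatever model.<sub_model>.hparams_grid contains.

```python
import random


def sample_trials(hparams_grid, n_trials, seed=0):
    """Draw n_trials random configurations, one value per hyperparameter."""
    rng = random.Random(seed)
    return [
        {name: rng.choice(values) for name, values in hparams_grid.items()}
        for _ in range(n_trials)
    ]


# Hypothetical grid; the real keys and values live in model.<sub_model>.hparams_grid.
example_grid = {
    "learning_rate": [0.01, 0.001, 0.0001],
    "batch_size": [64, 128, 256],
    "dropout_rate": [0.1, 0.2, 0.3],
}

for trial in sample_trials(example_grid, n_trials=5):
    print(trial)
```

Here `n_trials` plays the role that tune_range has in the flags above. A fixed seed makes the sampled configurations reproducible.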

Datasets

Specify a dataset or dataset generator, plus any additional parameters (e.g. set gamma for cancer_sim with dataset.coeff=1.0):

  • Synthetic Tumor Growth Simulator: +dataset=cancer_sim
  • MIMIC III Semi-synthetic Simulator (multiple treatments and outcomes): +dataset=mimic3_synthetic
  • MIMIC III Real-world dataset: +dataset=mimic3_real

Before running MIMIC III experiments, place the MIMIC-III-extract dataset (all_hourly_data.h5) in data/processed/.
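
A missing dataset file is cheaper to catch before training starts. The check below is a sketch, not repository code; the function names are hypothetical, and it assumes the script is run from the repository root so that data/processed/ resolves relative to it.

```python
from pathlib import Path


def mimic_extract_path(repo_root="."):
    """Return the expected location of the MIMIC-III-extract file."""
    return Path(repo_root) / "data" / "processed" / "all_hourly_data.h5"


def check_mimic_extract(repo_root="."):
    """Return the dataset path, or raise FileNotFoundError if it is missing."""
    path = mimic_extract_path(repo_root)
    if not path.is_file():
        raise FileNotFoundError(f"expected MIMIC-III-extract at {path}")
    return path
```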

Example of running Causal Transformer on the Synthetic Tumor Growth Generator with gamma = [1.0, 2.0, 3.0] and different random seeds, using the saved hyperparameters (the sweep below combines 3 hyperparameter configs with 5 seeds, a total of 15 subruns):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_multi.py -m +dataset=cancer_sim +backbone=ct +backbone/ct_hparams/cancer_sim_domain_conf='0','1','2' exp.seed=10,101,1010,10101,101010
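
Hydra's -m (multirun) flag launches one job per combination of the comma-separated values. Counting only the two lists written in the command above, the expansion can be sketched as a Cartesian product. This is an illustration of how the sweep unfolds, not repository code, and any extra sweep dimensions defined inside the configs are not modelled.

```python
from itertools import product

hparam_configs = ["0", "1", "2"]
seeds = [10, 101, 1010, 10101, 101010]

# One job per (config, seed) pair, as in a Cartesian-product sweep.
jobs = [
    [f"+backbone/ct_hparams/cancer_sim_domain_conf={cfg}", f"exp.seed={seed}"]
    for cfg, seed in product(hparam_configs, seeds)
]

print(len(jobs))
print(jobs[0])
```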

Project based on the cookiecutter data science project template. #cookiecutterdatascience

Contributors

samholt, valentyn1997
