
Synthetic Data Exploration: Variational Autoencoders

NHSX Analytics Unit - PhD Internship Project

About the Project

status: experimental

This repository holds code for the NHSX Analytics Unit PhD internship project (previously known as Synthetic Data Generation - VAE), contextualising and investigating the potential use of Variational AutoEncoders (VAEs) for synthetic health data generation. The work was initially undertaken by Dominic Danks (last commit to the repository: commit 88a4bdf) and has since been extended by David Brind.

Project Description - Synthetic Data Exploration: Variational Autoencoders

Note: No data, public or private, are shared in this repository.

Project Structure

  • The main code is found in the root of the repository (see Usage below for more information)
  • The accompanying report is also available in the reports folder
  • More information about the VAE with Differential Privacy can be found in the model card
  • scratch_vae_expts.py is similar to the files found within the investigations folder. To re-run our results, scratch_vae_expts.py is all you require. For an easier walkthrough of the code and our working process, use the corresponding notebooks in the investigations folder.

N.B. A modified copy of Opacus (v0.14.0), a library for training PyTorch models with differential privacy, is contained within the repository. A modified copy of RDT (v0.6.2) that allows us to set the random_state for our GMM transformer method is also contained within the repository. See the model card for more details.
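The reason a settable random_state matters for the GMM transformer can be illustrated with scikit-learn directly. This is a minimal sketch, not the repository's code: it uses sklearn's GaussianMixture to show that fixing random_state makes the fitted mixture, and hence the transformation derived from it, identical across runs.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Toy bimodal feature standing in for one continuous column.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-3, 1, 200), rng.normal(3, 1, 200)]).reshape(-1, 1)

# Fixing random_state makes the fitted mixture parameters identical
# across runs, so the derived transformation is reproducible.
gmm_a = GaussianMixture(n_components=2, random_state=42).fit(x)
gmm_b = GaussianMixture(n_components=2, random_state=42).fit(x)

same = np.allclose(gmm_a.means_, gmm_b.means_)
print(same)
```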

Built With

Python v3.8

Getting Started

Installation

To get a local copy up and running follow these simple steps.

To clone the repo:

git clone https://github.com/nhsx/SynthVAE.git

To create a suitable environment:

  • python -m venv synth_env
  • source synth_env/bin/activate
  • pip install -r requirements.txt
  • pip uninstall rdt (SDV installs rdt by default; however, we include a modified copy in this repository, so the released version must be uninstalled to avoid conflicts)

GPU Support

This code has been tested on CPU with the supplied torch v1.9.0, and has also been run in a GPU environment. The specifications of the GPU device used are as follows:

  • NVIDIA GeForce RTX 3070 Laptop GPU
  • CUDA v11.1
  • cuDNN v8.1.0 for CUDA v11.1

Refer to the PyTorch documentation for v1.9.0 (Linux & Windows wheels for CUDA 11.1) for the required pip install.

We use:

  • pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html after installing the requirements file, in order to get the GPU-supported versions of PyTorch that we require.

Note that the model used is a simple MLP encoder/decoder and we are working with tabular data, so the speed-up provided by a GPU is not very noticeable; training on the CPU version is perfectly feasible for users without GPU support.

Usage

There are three main areas of interest (two folders and a collection of Python files): Hyperparameter_Tuning, Investigations, and the files sdv_baselines.py, scratch_vae_expts.py and plot.py. Hyperparameter_Tuning and Investigations can use either the SUPPORT or internal MIMIC datasets; sdv_baselines.py, scratch_vae_expts.py and plot.py use SUPPORT only. The command-line usage for each is outlined below.

SDV Baselines

To reproduce the experiments contained in the report involving the SDV baseline models (e.g. CopulaGAN, CTGAN, GaussianCopula and TVAE), run sdv_baselines.py. The parameters can be found using the --help flag:

python sdv_baselines.py --help

usage: sdv_baselines.py [-h] [--n_runs N_RUNS] [--model_type {CopulaGAN,CTGAN,GaussianCopula,TVAE}] [--pre_proc_method {GMM, standard}] [--save_metrics {True, False}] [--gower {True, False}]

optional arguments:
  -h, --help            show this help message and exit
  --n_runs N_RUNS       set number of runs/seeds
  --model_type          {CopulaGAN,CTGAN,GaussianCopula,TVAE}
                        set model for baseline experiment
  --pre_proc_method     {GMM, standard}
                        set the pre-processing method
  --save_metrics        {True, False}        
                        set if you wish to save the metrics for this model run - saves default as Metric Breakdown.csv unless changed
  --gower               {True, False}
                        calculate the average gower distance
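The flags above follow the standard argparse pattern. Below is a minimal sketch of how such a CLI is wired up; the defaults shown are illustrative assumptions, not necessarily those used in sdv_baselines.py.

```python
import argparse

parser = argparse.ArgumentParser(prog="sdv_baselines.py")
parser.add_argument("--n_runs", type=int, default=10,
                    help="set number of runs/seeds")
parser.add_argument("--model_type",
                    choices=["CopulaGAN", "CTGAN", "GaussianCopula", "TVAE"],
                    default="CTGAN", help="set model for baseline experiment")
parser.add_argument("--pre_proc_method", choices=["GMM", "standard"],
                    default="GMM", help="set the pre-processing method")

# Parse an explicit argv list so the sketch runs outside a shell.
args = parser.parse_args(["--model_type", "TVAE", "--n_runs", "3"])
print(args.model_type, args.n_runs, args.pre_proc_method)
```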

Scratch VAE + Differential Privacy

To reproduce the experiments contained in the report involving the VAE with/without differential privacy, run scratch_vae_expts.py. The parameters can be found using the --help flag:

python scratch_vae_expts.py --help

usage: scratch_vae_expts.py [-h] [--n_runs N_RUNS] [--diff_priv DIFF_PRIV] [--n_epochs N_EPOCHS] [--save_file {True, False}] [--save_visualisation {True, False}] [--save_metrics {True, False}] [--pre_proc_method {GMM, standard}] [--gower {True, False}]

optional arguments:
  -h, --help            show this help message and exit
  --n_runs N_RUNS       set number of runs/seeds
  --diff_priv DIFF_PRIV
                        run VAE with differential privacy
  --n_epochs N_EPOCHS   number of epochs to train for (early stopping is used by default)
  --save_file           {True, False}
                        save trained model's state_dict as 'trained_SynthVAE.pt'
  --save_visualisation  {True, False}
                        save the loss & variable plots
  --save_metrics        {True, False}
                        calculate and save the metrics of the training set
  --pre_proc_method     {GMM, standard}
                        set the pre-processing method
  --gower               {True, False}
                        calculate the average gower distance

Code to load a saved model and generate correlation heatmaps is contained within plot.py. The file containing the saved model's state_dict should be provided via a command-line argument:

python plot.py --help

usage: plot.py [-h] [--save_file SAVEFILE] [--pre_proc_method {GMM, standard}]

optional arguments:
  -h, --help           show this help message and exit
  --save_file SAVEFILE  load trained model's state_dict from file
  --pre_proc_method    {GMM, standard}
                       specify the pre-processing method that you wish to employ

Outputs Of The Model

The model can output a selection of plots and metrics. These are controlled by the --save_metrics and --save_visualisation arguments, or by the functions plot_elbo, plot_likelihood_breakdown, plot_variable_distributions and distribution_metrics in the provided notebooks. The outputs include a graph of the ELBO breakdown during training, the breakdown of categorical and numerical likelihoods, a per-variable comparison of the original and synthetic data distributions, and a CSV file containing all the distributional metrics from SDV.

ELBO Breakdown

Likelihood Breakdown

Example Continuous Variable Comparison

Example Discrete Variable Comparison

The distributional metrics are written to a CSV in the following column order, with one row per run:

SVCDetection  GMLogLikelihood  CSTest  KSTest  KSTestExtended  ContinuousKLDivergence  DiscreteKLDivergence
0.32          -520.01          0.91    0.85    0.87            0.91                    0.97
0.31          -523.21          0.90    0.86    0.88            0.92                    0.99

For information on these metrics, see the report as well as the SDV single table metrics documentation.
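Summarising such a CSV across runs needs only the standard library. The sketch below uses the illustrative values from the table above (and assumes the second of the two duplicated header columns was meant to be KSTestExtended, with KSTest preceding it):

```python
import csv
import io
import statistics

# Inline stand-in for the metrics CSV produced by a multi-run experiment.
csv_text = """SVCDetection,GMLogLikelihood,CSTest,KSTest,KSTestExtended,ContinuousKLDivergence,DiscreteKLDivergence
0.32,-520.01,0.91,0.85,0.87,0.91,0.97
0.31,-523.21,0.90,0.86,0.88,0.92,0.99
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))
# Mean of each metric across runs/seeds.
means = {name: statistics.mean(float(row[name]) for row in rows) for name in rows[0]}
print(round(means["CSTest"], 3))  # 0.905
```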

Note On Reproducibility Of Results

In order to get reproducible results, we added a random_state argument to the RDT transformers, which sets sklearn's random_state. This means the GMM pre-processing method produces the same transformation on each run for the same dataset. We also set the PyTorch seed at the top of each file using the set_seed function found in utils. If you do not want to remove the randomness in training, delete the set_seed line at the start of the files.
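A sketch of what such a seeding helper looks like. The repository's utils.set_seed additionally seeds PyTorch; this illustration covers only the Python and NumPy generators:

```python
import random

import numpy as np

def set_seed(seed: int = 0) -> None:
    # The repository's utils.set_seed also calls torch.manual_seed;
    # this sketch seeds only the Python and NumPy RNGs.
    random.seed(seed)
    np.random.seed(seed)

set_seed(0)
first = np.random.rand(3)
set_seed(0)
second = np.random.rand(3)
print(np.allclose(first, second))  # identical draws after reseeding
```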

Dataset

Experiments are run against the Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT) dataset, accessed via the pycox Python library. Further experiments to test the scalability of the model were also performed on a pre-processed single table extracted from the MIMIC-III dataset. The pre-processing used to produce this single table can be found within the SynthVAE files.

Your dataset should follow the simple structure shown in the example table: continuous, categorical & datetime variables with no missingness or NaN values. It can contain as many columns and rows as required.
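A quick pandas check of whether a candidate table meets these requirements. The column names below are hypothetical and not part of the repository:

```python
import pandas as pd

# Hypothetical toy table: continuous, categorical and datetime columns
# with no missing values, matching the structure SynthVAE expects.
df = pd.DataFrame({
    "age": [63.5, 41.2, 55.0],
    "sex": ["F", "M", "F"],
    "admission": pd.to_datetime(["2019-01-03", "2019-02-14", "2019-03-01"]),
})

# SynthVAE cannot handle missingness, so verify there are no NaNs.
has_missing = bool(df.isna().any().any())
print(has_missing)  # False means the table is usable
```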

Roadmap

See the open issues for a list of proposed features (and known issues).

Contributing

Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are greatly appreciated.

  1. Fork the Project
  2. Create your Feature Branch (git checkout -b feature/AmazingFeature)
  3. Commit your Changes (git commit -m 'Add some AmazingFeature')
  4. Push to the Branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

See CONTRIBUTING.md for detailed guidance.

License

Distributed under the MIT License. See LICENSE for more information.

Contact

To find out more about the Analytics Unit visit our project website or get in touch at [email protected].

synthvae's People

Contributors

danjscho, davebrind, dependabot[bot], djdnx, jrpearson500


synthvae's Issues

Unable to reproduce the report's distribution metrics using SUPPORT

Hello,

I am trying to reproduce the distribution metrics established using SUPPORT, as stated on page 13 of the SynthVAE report.

I have downloaded available code and checked that my libraries are identical to those given in the requirements.txt file. I am using Python version 3.8.0.

I have run the following code (for both pre-processing methods) on Windows in Command Prompt:

python scratch_vae_expts.py --pre_proc_method GMM

and:

python scratch_vae_expts.py --pre_proc_method Standard

I wasn't clear on which pre-processing method was used in the report. However, in both cases the distribution metrics I computed for the VAE model are different to those stated in the PDF. Please can you suggest how to fix this? I have not modified the available code in any way. Perhaps the issue is due to seeding?

Thank you in advance.

Handling datetime generation

At the moment SDMetrics gives many warnings originating from line 40 of its utils file (https://github.com/sdv-dev/SDMetrics/blob/master/sdmetrics/utils.py). This function calculates value frequencies in the original and synthetic data and throws warnings when values in the synthetic data cannot be found in the real data.

It does this frequently for datetime columns, probably because generation produces them down to the smallest unit of time. We need to look at truncating datetime columns to match the unit of the input column.

At the moment this is just handled by suppressing the warnings.
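One possible shape of the truncation fix, sketched with pandas (not the repository's code): floor generated timestamps to the input column's resolution before frequencies are compared.

```python
import pandas as pd

# Generated timestamps carry sub-second precision that a date-only input
# column never had, so exact-value frequency comparisons always miss.
synthetic = pd.Series(pd.to_datetime([
    "2019-01-03 11:42:17.123456",
    "2019-02-14 05:09:58.000001",
]))

# Truncate to day resolution to match a date-only input column.
truncated = synthetic.dt.floor("D")
print(truncated.dt.strftime("%Y-%m-%d").tolist())  # ['2019-01-03', '2019-02-14']
```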

Datetime handling in VAE

Describe the bug
When passing datetime columns, the reverse transform does not work as expected and throws an error.

To Reproduce
Steps to reproduce the behavior:

  1. Supply a datetime column in the input dataframe
  2. Run training of scratch VAE
  3. Error thrown when trying to inverse transform datetime back

Expected behavior
Should be able to take and handle datetime column as continuous, not discrete

Other details
Possible fix is by changing order of HyperTransformer and StandardScaler in output transformations

Poor correlations when using `GMM` preprocessing

Describe the bug
When running the GMM preprocessing, we are seeing poor correlations coming out of the trained VAE. Investigate this further.

To Reproduce
Steps to reproduce the behavior:

  1. Run any training setup involving GMM preprocessing
  2. Compare correlations from original and generated data via something like .corr()

Expected behavior
Closer matching between the correlations of the generated data and the original data when using GMM preprocessing

Additional context
Comparison to TVAE in CTGAN from SDV is worth utilising (https://github.com/sdv-dev/CTGAN)
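The comparison in step 2 can be sketched as follows. The "synthetic" sample here is a random stand-in with independent columns, used only to show how a correlation gap would surface; in practice it would be the trained VAE's generated data.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Original data with a strong a-b correlation.
original = pd.DataFrame({"a": rng.normal(size=500)})
original["b"] = 0.8 * original["a"] + 0.2 * rng.normal(size=500)

# Stand-in "synthetic" sample with independent columns; in practice this
# would come from the trained VAE.
synthetic = pd.DataFrame({"a": rng.normal(size=500), "b": rng.normal(size=500)})

# A large absolute gap in pairwise correlation flags the problem.
gap = abs(original.corr().loc["a", "b"] - synthetic.corr().loc["a", "b"])
print(gap > 0.5)
```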

Expand metric options

Privacy and fairness metrics are currently lacking. It would be good to deepen the library of metrics available to the user so that they can get a complete breakdown of SynthVAE's ability to handle these two critical issues.

Handling Missing Data

SynthVAE currently cannot handle missing data in training sets. It would be interesting to look at ways to address this in future.

Update Change log for minor fixes

