
Multimodal Deep Markov Models

A PyTorch implementation of the Multimodal Deep Markov Model (MDMM) and associated inference methods described in Factorized Inference in Deep Markov Models for Incomplete Multimodal Time Series. Please cite this paper if you use or modify any of this code.

Generalizes the Multimodal Variational Auto-Encoder (MVAE) by Wu & Goodman and the Deep Markov Model by Krishnan et al.

Setup

After creating a virtual environment with virtualenv or conda, simply install the dependencies in requirements.txt. The code is compatible with both Python 2.7 and Python 3.

virtualenv venv
source venv/bin/activate
pip install -r requirements.txt

Alternatively, one can install the following packages directly through pip:

# For basic functionality
pip install torch==1.1.0 pandas pyyaml matplotlib

# To download and pre-process the Weizmann video dataset
sudo apt-get install ffmpeg
pip install scipy scikit-video scikit-image requests tqdm opencv-python

# To run the experiment scripts using Ray Tune
pip install ray psutil

Datasets

Before training, the datasets need to be generated or downloaded.

To generate the Spirals dataset, make datasets the current directory, then run python spirals.py. For a list of options, run python spirals.py -h.

To automatically download and preprocess the Weizmann video dataset of human actions, again make sure that datasets is the current directory, then run python weizmann.py.
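For example, starting from the repository root:

cd datasets
python spirals.py     # generate the Spirals dataset
python weizmann.py    # download and preprocess the Weizmann dataset
cd ..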

If automated download fails, create a directory called weizmann in datasets, and download the zip files and segmentation masks from the Weizmann dataset website.

Models and Inference Methods

The models subdirectory contains three different inference methods that can be used with MDMM (or MDMM-like) architectures:

  • dmm.py implements the MDMM with Backward Forward Variational Inference (BFVI), as described in our paper. Refer to the included docstrings for a full list of options.

  • dks.py implements the MDMM with the RNN-based structured inference networks described by Krishnan et al. By providing different options to the constructor, one can use either forward or backward RNN networks, and toggle different methods for handling missing data. Refer to the docstrings for details.

  • vrnn.py implements a multimodal version of the Variational Recurrent Neural Network (VRNN) described by Chung et al. This is similar to using dks.py with a forward RNN.

Training

The training code for the Spirals dataset can be run by calling python spirals.py. Default hyper-parameters are used; run python spirals.py -h for a full list of options.

The training code for the Weizmann dataset can be run by calling python weizmann.py. Again, default hyper-parameters are used; run python weizmann.py -h for a full list of options.

To specify which inference method to use, use the --model flag with either dmm or dks. To specify which modalities to load and train on, use the --modalities flag. To visualize predictions while training, add the --visualize flag. Pretrained models can be evaluated by adding --load PATH/TO/MODEL.
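For example (the modality names below are illustrative; run python weizmann.py -h to see the actual choices):

python weizmann.py --model dmm --modalities video mask --visualize
python weizmann.py --model dks --load PATH/TO/MODEL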

An abstract Trainer class can be found in trainer.py, allowing training code to be easily written for other multimodal sequential datasets.

Experiments

Ray Tune can be used to easily run experiments across multiple sets of hyper-parameters and multiple trials. Make sure ray is installed for this to work. Install tensorboard and tensorflow as well if you would like to visualize the loss curves in TensorBoard.
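For example:

pip install tensorboard tensorflow   # optional, for visualizing loss curves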

Comparing different inference methods on a range of tasks

For the Spirals dataset: python -m experiments.spirals_suite --trial_cpus N --trial_gpus N

For the Weizmann dataset: python -m experiments.weizmann_suite --trial_cpus N --trial_gpus N

Learning with uniformly random missing data

For the Spirals dataset: python -m experiments.spirals_partial --trial_cpus N --trial_gpus N

For the Weizmann dataset: python -m experiments.weizmann_partial --trial_cpus N --trial_gpus N

Semi-supervised learning

Semi-supervised learning refers to learning where some sequences have entire modalities removed.

For the Spirals dataset: python -m experiments.spirals_semisup --trial_cpus N --trial_gpus N

For the Weizmann dataset: python -m experiments.weizmann_semisup --trial_cpus N --trial_gpus N

Examples

Below are spiral reconstructions produced by each inference method on a range of inference tasks. BFVI (our method) consistently produces good reconstructions across all tasks, unlike the RNN-based methods.

[Figure: comparison of spiral reconstructions]

Video reconstructions from the Weizmann dataset are shown below, comparing BFVI to the next best method (B-Skip). Only video data is provided; the silhouette masks and action labels have to be inferred. Again, it can be seen that BFVI produces better reconstructions, as well as better silhouette and action predictions.

[Figure: comparison of video reconstructions]

Refer to the paper for more examples.

Bugs & Questions

Feel free to raise issues, or email xuan [at] mit [dot] edu with questions.


Issues

Bug report

I am using the latest PyTorch (version 1.2.0) and found that the following syntax (which occurs in about 5 places) is no longer supported. The mask also uses uint8 as its dtype.

        mask_m = 1 - torch.isnan(inputs[m]).flatten(2,-1).any(dim=-1)

I have fixed it on my machine. Could you change the mask to a bool type to make reproduction easier?

       mask_m = ~ torch.isnan(inputs[m]).flatten(2,-1).any(dim=-1)

Thanks
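As an illustrative aside (not part of the original issue): on recent PyTorch versions torch.isnan returns a bool tensor, for which the - operator is unsupported, while logical negation with ~ works. A minimal check:

import torch

inputs_m = torch.tensor([[0.5, float('nan')], [1.0, 2.0]])
mask_m = ~torch.isnan(inputs_m).any(dim=-1)   # bool mask: True where the row has no NaNs
print(mask_m)                                  # tensor([False,  True])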

A question about Algorithm 2

Hello, Zhi-Xuan. This work is quite nice and the theory is sound. I have a question about the details of Eqn. (6) and the approximation in Eqn. (12).
Your choice (as in Algorithm 2) is:
[image: approximation 1]
The common choice is:
[image: approximation 2]

I think your choice is beneficial, and I would argue it can improve accuracy over the common one: p(z_t|x_{1:T}) should give a more accurate estimate than p(z_t|x_{1:t-1}), but the mathematical relationship also changes. Could you help me with further clarification?

Thanks
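For readers without the images above, the two factors being compared are, paraphrasing the question rather than the paper's exact equations:

\[
p(z_t \mid x_{1:T}) \quad \text{(the choice in Algorithm 2, conditioning on the full sequence)}
\qquad \text{vs.} \qquad
p(z_t \mid x_{1:t-1}) \quad \text{(the common choice, conditioning only on past observations)}
\]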

Slow training solution

Hi, your cost functions nll (negative log-likelihood) and kld (KL divergence) have some training problems (both vanishing and exploding gradients).
Your nll:
[image: the NLL term]
The gradients:
[images: the gradients of the NLL]
The gradient explosion seems obvious, and my experiments also showed that this loss suffers from vanishing gradients when the input x is far from the mean.

If we change it without changing its meaning:
[image: the rewritten loss]

After this change, I found that the model trains much faster (the 500/3000-epoch configuration is a very large number for the original setting). With the changed loss, you can train the model without any annealing trick, and the loss will not explode.

Maybe it helps. Thank you.
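As a general illustration of the kind of reparameterization suggested above (a generic sketch, not the exact change from this issue): predicting the log-variance is a common way to keep the Gaussian NLL numerically stable when the input is far from the mean.

import math
import torch

def gaussian_nll(x, mean, logvar):
    # Per-element negative log-likelihood of x under N(mean, exp(logvar)).
    # Parameterizing by log-variance avoids dividing by a standard deviation
    # that can shrink towards zero, a common source of exploding gradients.
    return 0.5 * (math.log(2 * math.pi) + logvar
                  + (x - mean).pow(2) * torch.exp(-logvar))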

Bug

Quotients of Gaussian distributions may not be correctly implemented.
You use the following:

dmm.py -> forward() -> # smooth pass -> inv_std = -inv_std

But your function product_of_experts() can't work correctly with this.

Thanks.
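For reference, an illustrative sketch of the standard precision-weighted product of Gaussian experts (not the repository's actual product_of_experts implementation); it assumes positive variances, which is presumably what the report above is pointing at:

import torch

def product_of_gaussians(means, variances):
    # means, variances: tensors of shape (num_experts, ...), with variances > 0.
    # For a product of Gaussian experts, precisions add and the mean is the
    # precision-weighted average of the expert means.
    precisions = 1.0 / variances
    total_precision = precisions.sum(dim=0)
    mean = (means * precisions).sum(dim=0) / total_precision
    return mean, 1.0 / total_precision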
