Git Product home page Git Product logo

ddpae-video-prediction's Introduction

DDPAE

This is the code for the paper

Learning to Decompose and Disentangle Representations for Video Prediction
Jun-Ting Hsieh, Bingbin Liu, De-An Huang, Li Fei-Fei, Juan Carlos Niebles

Will be presented at NeurIPS 2018.

Our goal is to predict future video frames given a sequence of input frames. Our key insight into this high-dimensional prediction problem is to decompose it into sub-problems that can be more easily predicted. To this end, we proposed the Decompositional Disentangled Predictive Auto-Encoder (DDPAE), a framework that combines structured probabilistic models and deep networks to automatically

  1. decompose the high-dimensional video into components, and
  2. disentangle each component to have low-dimensional temporal dynamics that are easier to predict.

Our DDPAE is able to learn both the latent decomposition and disentanglement without explicit supervision. We show strong results on the Moving MNIST and Bouncing Balls datasets.

Note: Our code is built with Pyro. We encourage readers to read the Pyro tutorial and AIR tutorial to better understand our code.

If you find this code useful in your research then please cite:

@inproceedings{hsieh2018learning,
  title={Learning to Decompose and Disentangle Representations for Video Prediction},
  author={Hsieh, Jun-Ting and Liu, Bingbin and Huang, De-An and Fei-Fei, Li and Niebles, Juan Carlos},
  booktitle={NeurIPS},
  year={2018}
}

Setup

All code was developed and tested on Python 3.6, PyTorch 0.3.1, and Pyro 0.2.
To view results, you need to install tensorflow and tensorbaordX.

Datasets

Moving MNIST

Download the original MNIST dataset and the Moving MNIST test set by running

./datasets/moving_mnist/download.sh

Bouncing Balls

We generate our Bouncing Balls dataset with the Neural Physics Engine, used by Chang et al, 2017.

First, clone the bouncing_balls submodule,

git submodule update --init --recursive

Note that this is slightly modified from the original dynamics repository.

npm must be installed. With Anaconda, install by running: conda install -c conda-forge nodejs. Then,

cd datasets/bouncing_balls/src/js/
npm install

Generate training and testing data,

node demo/js/generate.js -e balls -n 4 -t 60 -s 50000
node demo/js/generate.js -e balls -n 4 -t 60 -s 2000

This will generate two folders, balls_n4_t60_ex50000 and balls_n4_t60_ex2000, in the datasets/bouncing_balls/data/ directory. Move these folders to the desired location (dataset path). Next, modify the root variable in datasets/bouncing_balls/process.py and run python datasets/bouncing_balls/process.py for both files.

Training

Remember to set the arguments, such as ckpt_dir, ckpt_name, etc. A tb_train directory will be created in ckpt_dir. You can see the training progress by running tensorboard --logdir ckpt_dir/tb_train.

Moving MNIST

Run

./scripts/train_mnist.sh

In our paper, we let the model train for 200k iterations. However, 50k iterations should be enough to see the model separating the two digits.

Moving MNIST

Run

./scripts/train_balls.sh

We also train for 200k iterations.

ddpae-video-prediction's People

Contributors

jthsieh avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.