
Multimodal VAE - PyTorch

This repository contains a PyTorch implementation of the paper "Multimodal representation models for prediction and control from partial information".

Download and process the data

python get_drive_file.py 1Nn-ONccUbW1cBwm6nRhgF-zjtKoB8zO6 data2020.zip
unzip data2020.zip
# see data.yml definition below
python prepare_data.py -opts data.yml
rm -r data2020
rm data2020.zip

Data folder structure

/mydataset
    /0 (trajectory)
        /action1 (action label)
            0.jpeg (only jpeg for now)
            1.jpeg
            ...
            N.jpeg (N can be any number)
            objects_0.txt [N x D] matrix
            anything_0.txt [N x K] matrix
        /action2 (optional)
    /1
        /action1
            0.jpeg
            1.jpeg
            ...
            P.jpeg (can be different from N above)
            objects_1.txt [P x D] matrix
            anything_1.txt [P x K] matrix
    ...
    /M
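The layout above implies a few invariants that are easy to check before running prepare_data.py. The following is a minimal sketch and is not part of the repository. `validate_dataset` and `read_matrix_shape` are hypothetical helpers. The sketch assumes that each `<modality>_<trajectory>.txt` file is a whitespace-separated matrix with one row per jpeg frame, and that frames are numbered from 0.

```python
import os
import re

# Frame files are named 0.jpeg, 1.jpeg, ... (only jpeg is supported for now).
FRAME_RE = re.compile(r"^(\d+)\.jpeg$")


def read_matrix_shape(path):
    """Return (rows, cols) of a whitespace-separated text matrix, or None if ragged."""
    with open(path) as f:
        rows = [line.split() for line in f if line.strip()]
    widths = {len(r) for r in rows}
    if len(widths) > 1:
        return None
    return len(rows), (widths.pop() if widths else 0)


def validate_dataset(root, actions, modalities):
    """Check root/<traj>/<action>/{i.jpeg, <mod>_<traj>.txt} (hypothetical helper).

    Assumes every <mod>_<traj>.txt has one row per jpeg frame.
    Returns a list of human-readable problems; an empty list means none were found.
    """
    problems = []
    trajs = sorted((d for d in os.listdir(root) if d.isdigit()), key=int)
    for traj in trajs:
        for action in actions:
            folder = os.path.join(root, traj, action)
            if not os.path.isdir(folder):
                problems.append(f"{folder}: missing action folder")
                continue
            # Collect the integer indices of all jpeg frames in this folder.
            frames = sorted(
                int(m.group(1)) for m in map(FRAME_RE.match, os.listdir(folder)) if m
            )
            if frames != list(range(len(frames))):
                problems.append(f"{folder}: frame indices are not contiguous from 0")
            for mod in modalities:
                if mod == "img":
                    continue  # images are the jpeg frames themselves
                txt = os.path.join(folder, f"{mod}_{traj}.txt")
                if not os.path.isfile(txt):
                    problems.append(f"{txt}: missing")
                    continue
                shape = read_matrix_shape(txt)
                if shape is None:
                    problems.append(f"{txt}: rows have different lengths")
                elif shape[0] != len(frames):
                    problems.append(f"{txt}: {shape[0]} rows but {len(frames)} frames")
    return problems
```

For example, `validate_dataset("mydataset", ["action1"], ["img", "anything", "objects"])` returns a list of messages. The list is empty when no mismatch is found.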

To preprocess the data and run the training script, you need two YAML option files. Here are examples of each:

Example data.yml

path: "mydataset"
actions: ["action1"]
modality: ["img", "anything", "objects"]  # the first modality should always be img
N: 10
sp_tr: 7  # train split
sp_vl: 8  # validation split
shuffle: false  # whether to shuffle trajectories
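The data.yml comments say only that sp_tr and sp_vl are the train and validation splits. Given N: 10, sp_tr: 7 and sp_vl: 8, one plausible reading is that they are cut points over the trajectory indices. The sketch below assumes that reading, which gives train = 0..6, validation = 7 and test = 8..9. `split_trajectories` and its `seed` argument are hypothetical, and prepare_data.py may split the data differently.

```python
import random


def split_trajectories(N, sp_tr, sp_vl, shuffle=False, seed=0):
    """Split trajectory indices 0..N-1 into train/validation/test lists.

    Hypothetical reading of data.yml: sp_tr and sp_vl are cut points, so
    train = [0, sp_tr), validation = [sp_tr, sp_vl), test = [sp_vl, N).
    """
    if not 0 <= sp_tr <= sp_vl <= N:
        raise ValueError("expected 0 <= sp_tr <= sp_vl <= N")
    order = list(range(N))
    if shuffle:
        # `seed` is not a data.yml option; it only makes this sketch reproducible.
        random.Random(seed).shuffle(order)
    return order[:sp_tr], order[sp_tr:sp_vl], order[sp_vl:]


opts = {"N": 10, "sp_tr": 7, "sp_vl": 8, "shuffle": False}
train, val, test = split_trajectories(opts["N"], opts["sp_tr"], opts["sp_vl"], opts["shuffle"])
print(train, val, test)  # prints [0, 1, 2, 3, 4, 5, 6] [7] [8, 9]
```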

Example opts.yml

save: save/test
data: data_folder_path
device: cuda
modality: ["img", "anything", "objects"]
action: ["action1"]
batch_size: 128
epoch: 10
lambda: 1.0
beta: 0.0
init_method: xavier
lr: 0.0005
reduce: true
mse: true
beta_decay: 0.0
in_blocks: [
  [-2, 1024, 128, 6, 32, 64, 64, 128, 128, 256],  # image encoder
  [-1, 28, 32, 64, 64, 128, 128, 256, 128],  # anything encoder
  [-1, 16, 32, 64, 64, 128, 128, 256, 128]  # objects encoder
]
in_shared: [384, 256]  # shared encoder
out_shared: [128, 384]  # shared decoder
out_blocks: [
  [-2, 128, 1024, 256, 256, 128, 128, 64, 64, 32],  # image decoder
  [-1, 128, 256, 128, 128, 64, 64, 32, 56],  # anything decoder
  [-1, 128, 256, 128, 128, 64, 64, 32, 32]  # objects decoder
]
traj_count: 6
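The opts.yml example does not document what `beta` and `mse` control. In a standard VAE, the training loss is a reconstruction term plus the KL divergence between the approximate posterior N(mu, sigma^2) and a standard normal prior, and the weight on the KL term is conventionally called beta. The sketch below implements that objective in plain Python. It assumes, without confirmation from this README, that `beta` weights the KL term and that `mse: true` selects a mean-squared-error reconstruction loss. `lambda` and `beta_decay` are not modeled.

```python
import math


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)) for one sample, summed over latent dims."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def mse(x, x_hat):
    """Mean squared error between a target vector and its reconstruction."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def vae_loss(x, x_hat, mu, logvar, beta):
    """Reconstruction + beta * KL; assumes `beta` in opts.yml plays this role."""
    return mse(x, x_hat) + beta * kl_to_standard_normal(mu, logvar)
```

Under this reading, the example's beta: 0.0 turns the KL term off and leaves a plain reconstruction objective.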

Prepare the dataset

python prepare_data.py -opts data.yml

Train the model

python train.py -opts opts.yml

You can monitor training progress with TensorBoard, where <savefolder> is the `save` path from opts.yml:

tensorboard --logdir <savefolder>/log

Test the model

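While testing, you can optionally ban some modalities to measure how accurately the model reconstructs the input and forecasts the previous and next timesteps from the remaining modalities. The `-banned` option takes one 0/1 value per modality, in the same order as `modality` in opts.yml (1 means banned).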

Without banning any modality:

python test.py -opts opts.yml -banned 0 0 0 -prefix no_ban

Or, ban the objects modality:

python test.py -opts opts.yml -banned 0 0 1 -prefix ban_objects
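Assuming the `-banned` values map one-to-one onto the `modality` list in opts.yml, here is a minimal sketch of what banning could mean at the input level. It replaces each banned modality's input with zeros. Zero-filling is a common choice but a hypothetical one here, and test.py may mask modalities differently. `apply_ban` is a hypothetical helper.

```python
def apply_ban(inputs, modalities, banned):
    """Zero out the inputs of banned modalities (hypothetical masking scheme).

    inputs: dict mapping modality name -> list of floats for one timestep.
    modalities: the `modality` list from opts.yml, e.g. ["img", "anything", "objects"].
    banned: one 0/1 flag per modality, in the same order as `modalities`.
    """
    if len(banned) != len(modalities):
        raise ValueError("need one -banned flag per modality")
    masked = {}
    for name, flag in zip(modalities, banned):
        # Banned modalities become all-zero vectors; others are copied unchanged.
        masked[name] = [0.0] * len(inputs[name]) if flag else list(inputs[name])
    return masked
```

With `banned = [0, 0, 1]`, as in the ban_objects command, only the objects input is zeroed. The img and anything inputs pass through unchanged.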

Contributors

alper111, codacy-badger
