Weakly-supervised Disentangled Representation Learning

A PyTorch implementation of weakly-supervised disentangled representation learning methods based on Variational Autoencoders (VAEs).

Supported algorithms

For comparison, I also implement unsupervised disentangled representation learning methods.

Requirements

  • Python >= 3.6

Installation

git clone https://github.com/koukyo1994/weakly-supervised-disentangled-representations
cd weakly-supervised-disentangled-representations
pip install -r requirements.txt

Data preparation

make prepare

To run the code

The implementation is a config-based pipeline. All you need to do is write a new config file and pass it explicitly to main.py, as follows:

python main.py --config configs/<your config>.yml
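As a rough illustration of the command above, here is a minimal sketch of how a `--config` flag can be parsed with `argparse`. Only the flag name comes from this README; the `build_parser` helper and the `configs/example.yml` path are hypothetical, and the actual main.py may handle its arguments differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: only the `--config` flag is taken from the README.
    parser = argparse.ArgumentParser(description="Run the pipeline from a YAML config")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to a YAML config file under configs/",
    )
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
# `configs/example.yml` is a made-up file name used only for illustration.
args = build_parser().parse_args(["--config", "configs/example.yml"])
print(args.config)
```

Marking the flag as required makes the script fail early with a usage message when no config is given, which fits the "load it explicitly" wording above.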

The structure of the config file is below:

globals:
  seed: 1213  # Whatever integer you want

models:
  name: BetaVAE  # Model name. Make sure the model is implemented; check `models/__init__.py` to see which models are currently available.
  params:  # Model-specific parameters. See the implementation of each model to check which arguments it requires.
    input_shape: [1, 64, 64]
    n_latents: 10
    beta: 16.0

dataset:
  name: dsprites_full  # Valid name for `disentanglement_lib/data/ground_truth/named_data/get_named_ground_truth_data`
  type: unsupervised  # Either `unsupervised` or `weak`
  params:  # Arguments for the PyTorch dataset in `dataset/pytorch.py`
    iterator_len: 20000

loader:  # Arguments for `torch.utils.data.DataLoader`
  batch_size: 64

optimizer:
  name: Adam  # Name of the optimizer. Any optimizer implemented in `torch.optim` can be used.
  params:  # Arguments for the optimizer
    lr: 0.0001

training:
  epochs: 1000

logging:
  validate_interval: 100  # Interval between validations. During validation, the pipeline outputs reconstruction images, latent traversal GIF and PNG files, and histograms of the latent vectors, and it calculates disentanglement metrics. This takes some time, so a small interval makes the whole run much slower.
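Before launching a long run, it can help to sanity-check a parsed config. The sketch below is not part of this repository: `validate_config` is a hypothetical helper, and it assumes that every top-level section shown above is required, which the pipeline itself may not enforce. It works on a plain dict, such as the one a YAML parser would return for the file above.

```python
# Top-level sections that appear in the example config (assumed to be required).
REQUIRED_SECTIONS = ("globals", "models", "dataset", "loader", "optimizer", "training", "logging")
# The README states that dataset.type is either `unsupervised` or `weak`.
DATASET_TYPES = ("unsupervised", "weak")


def validate_config(config: dict) -> list:
    """Return a list of problems found; an empty list means none were found."""
    problems = []
    for section in REQUIRED_SECTIONS:
        if section not in config:
            problems.append(f"missing section: {section}")

    dataset_type = config.get("dataset", {}).get("type")
    if dataset_type not in DATASET_TYPES:
        problems.append(f"dataset.type must be one of {DATASET_TYPES}, got {dataset_type!r}")

    seed = config.get("globals", {}).get("seed")
    if not isinstance(seed, int):
        problems.append("globals.seed must be an integer")
    return problems


# The example config from this README, written as the dict a YAML parser would produce.
example_config = {
    "globals": {"seed": 1213},
    "models": {
        "name": "BetaVAE",
        "params": {"input_shape": [1, 64, 64], "n_latents": 10, "beta": 16.0},
    },
    "dataset": {"name": "dsprites_full", "type": "unsupervised", "params": {"iterator_len": 20000}},
    "loader": {"batch_size": 64},
    "optimizer": {"name": "Adam", "params": {"lr": 0.0001}},
    "training": {"epochs": 1000},
    "logging": {"validate_interval": 100},
}

print(validate_config(example_config))
```

Returning a list of messages rather than raising on the first error lets you see every problem in a config at once.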
