vsg's Introduction

Learning Robust Dynamics Through Variational Sparse Gating

Learning world models from their sensory inputs enables agents to plan for actions by imagining their future outcomes. World models have previously been shown to improve sample-efficiency in simulated environments with few objects, but have not yet been applied successfully to environments with many objects. In environments with many objects, often only a small number of them are moving or interacting at the same time. In this paper, we investigate integrating this inductive bias of sparse interactions into the latent dynamics of world models trained from pixels. First, we introduce Variational Sparse Gating (VSG), a latent dynamics model that updates its feature dimensions sparsely through stochastic binary gates. Moreover, we propose a simplified architecture Simple Variational Sparse Gating (SVSG) that removes the deterministic pathway of previous models, resulting in a fully stochastic transition function that leverages the VSG mechanism. We evaluate the two model architectures in the BringBackShapes (BBS) environment that features a large number of moving objects and partial observability, demonstrating clear improvements over prior models.
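
The sparse update described above can be illustrated with a short sketch. It is not taken from this repository: it is a minimal pure-Python illustration, assuming the update keeps a feature dimension unchanged when its gate is 0 and overwrites it with a candidate value when its gate is 1, and assuming the gate prior enters as the Bernoulli prior of a KL regularizer. The function names, the fixed gate probabilities, and the candidate values are hypothetical; in the model the candidate state and gate probabilities would be predicted by networks.

```python
import math
import random


def sample_gates(gate_probs, rng):
    """Draw one binary gate per feature dimension: 1 = update, 0 = keep."""
    return [1 if rng.random() < p else 0 for p in gate_probs]


def sparse_update(prev_state, candidate, gates):
    """Overwrite only the gated dimensions; copy the rest from prev_state."""
    return [c if g else h for h, c, g in zip(prev_state, candidate, gates)]


def bernoulli_kl(q, p):
    """KL(Bernoulli(q) || Bernoulli(p)) for probabilities strictly in (0, 1)."""
    return q * math.log(q / p) + (1 - q) * math.log((1 - q) / (1 - p))


rng = random.Random(0)
prev_state = [0.0] * 8
candidate = [1.0] * 8     # hypothetical stand-in for a learned candidate state
gate_probs = [0.4] * 8    # hypothetical stand-in for learned gate probabilities
gate_prior = 0.4          # assumed to play the role of GATE_PRIOR in the scripts

gates = sample_gates(gate_probs, rng)
new_state = sparse_update(prev_state, candidate, gates)

# The regularizer is zero when the gate probabilities equal the prior and
# grows as they move away from it.
kl_to_prior = sum(bernoulli_kl(q, gate_prior) for q in gate_probs)
```

With a lower gate prior, the regularizer would push the model toward updating fewer dimensions per step, which is the sparsity bias the abstract refers to.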

Setup

The dependencies can be installed using the requirements.txt file:

cd VSG
virtualenv --no-download VSG
source VSG/bin/activate
pip install --upgrade pip
pip install tensorflow==2.4.1 tensorflow_probability==0.12.2
pip install -r requirements.txt

NOTE: In case there are issues with numpy, specifically NotImplementedError: Cannot convert a symbolic Tensor (strided_slice:0) to a numpy array., follow the fix mentioned here.

BringBackShapes

To conduct experiments on the proposed BringBackShapes environment, first install the environment following the instructions here.

bash scripts/bringbackshapes.sh {MODEL} {SUFFIX} sparse 3000 {DISTRACTORS} 5 False False False 125 {SIZE} {GATE_PRIOR} {SEED}

Variable     Description
DISTRACTORS  Number of stochastic distractors in the env
SIZE         Scale of the area to control partial observability, 1.0 refers to the Basic version
MODEL        Name of the agent
SUFFIX       Name of the experiment for logging
GATE_PRIOR   Gate prior probabilities for sparse gating mechanism in VSG/SVSG
SEED         Seed parameter

An example run is

bash scripts/bringbackshapes.sh VSG baseline sparse 3000 0 5 False False False 125 1.0 0.4 1
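
When launching runs from Python, the positional arguments are easy to misorder. The helper below is a hypothetical convenience wrapper, not part of this repository. It fills in the template above from named parameters and leaves the fixed positional values (`sparse`, `3000`, `5`, `False False False`, `125`) exactly as they appear in the template, since their meaning is not documented here. The range check on `gate_prior` is an assumption based on it being described as a probability.

```python
import shlex


def build_bbs_command(model, suffix, distractors, size, gate_prior, seed):
    """Return the argument list for scripts/bringbackshapes.sh."""
    # Assumption: the gate prior is a probability, so it should lie in [0, 1].
    if not 0.0 <= gate_prior <= 1.0:
        raise ValueError("gate_prior should be a probability in [0, 1]")
    return [
        "bash", "scripts/bringbackshapes.sh",
        model, suffix,
        "sparse", "3000",                        # fixed values from the template
        str(distractors),
        "5", "False", "False", "False", "125",   # fixed values from the template
        str(size), str(gate_prior), str(seed),
    ]


cmd = build_bbs_command("VSG", "baseline", distractors=0, size=1.0,
                        gate_prior=0.4, seed=1)
print(shlex.join(cmd))  # shell-ready string, requires Python 3.8+
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root to start the run.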

DeepMind Control Suite

For running experiments on tasks from DeepMind Control Suite, first install the dm_control repository following the instructions here. Then use the command below to run with the appropriate config.

Variable     Description
TASK         DMC Task to train on
MODEL        Choose from DreamerV1, DreamerV2, VSG or SVSG
SUFFIX       Name of the experiment for logging
GATE_PRIOR   Prior gate probability for VSG or SVSG
DIM          Size of the latent state
SEED         Seed parameter

bash scripts/dmc.sh {TASK} {MODEL} {SUFFIX} {GATE_PRIOR} {DIM} {SEED}

For example, to run VSG on walker_walk:

bash scripts/dmc.sh walker_walk VSG baseline 0.4 1024 1
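
Experiments of this kind are usually repeated over several seeds and hyperparameter values. The sketch below generates one `scripts/dmc.sh` command per combination of gate prior and seed. The function name, the suffix naming scheme, the extra gate prior value 0.3, and the seed list are all hypothetical and chosen only for illustration; the argument order follows the template above.

```python
import itertools
import shlex


def dmc_commands(task, model, suffix, gate_priors, dim, seeds):
    """Yield one scripts/dmc.sh argument list per (gate_prior, seed) pair."""
    for gate_prior, seed in itertools.product(gate_priors, seeds):
        # Hypothetical naming scheme so that each run logs under its own name.
        run_suffix = f"{suffix}_gp{gate_prior}_s{seed}"
        yield ["bash", "scripts/dmc.sh", task, model, run_suffix,
               str(gate_prior), str(dim), str(seed)]


commands = list(dmc_commands("walker_walk", "VSG", "baseline",
                             gate_priors=[0.3, 0.4], dim=1024, seeds=[1, 2, 3]))
for cmd in commands:
    print(shlex.join(cmd))
```

Printing the commands first makes it possible to inspect the sweep before submitting the runs to a scheduler or executing them one by one.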

Bibtex

If you find this code useful, please cite it in your paper:

@InProceedings{Jain22,
    author    = "Jain, Arnav Kumar and Sujit, Shivakanth and Joshi, Shruti and Michalski, Vincent and Hafner, Danijar and Kahou, Samira Ebrahimi",
    title     = "Learning Robust Dynamics through Variational Sparse Gating",
    booktitle = {Advances in Neural Information Processing Systems},
    month     = {December},
    year      = {2022}
}

Acknowledgements

This code was developed using DreamerV2.

vsg's Issues

A warning appears after I run the program and cannot be removed

I tried everything, but I still couldn't get rid of it.
The warning message is as follows:
WARNING:root:
Distribution subclass SafeTruncatedNormal inherits `_parameter_properties` from its parent (TruncatedNormal) while also redefining `__init__`. The inherited annotations cover the following parameters: dict_keys(['loc', 'scale', 'low', 'high']). It is likely that these do not match the subclass parameters. This may lead to errors when computing batch shapes, slicing into batch dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor (e.g., when it is passed or returned from a `tf.function`), and possibly other cases. The recommended pattern for distribution subclasses is to define a new `_parameter_properties` method with the subclass parameters, and to store the corresponding parameter values as `self._parameters` in `__init__`, after calling the superclass constructor:

```
class MySubclass(tfd.SomeDistribution):

  def __init__(self, param_a, param_b):
    parameters = dict(locals())
    # ... do subclass initialization ...
    super(MySubclass, self).__init__(**base_class_params)
    # Ensure that the subclass (not base class) parameters are stored.
    self._parameters = parameters

  def _parameter_properties(self, dtype, num_classes=None):
    return dict(
      # Annotations may optionally specify properties, such as `event_ndims`,
      # `default_constraining_bijector_fn`, `specifies_shape`, etc.; see
      # the `ParameterProperties` documentation for details.
      param_a=tfp.util.ParameterProperties(),
      param_b=tfp.util.ParameterProperties())
```

Although it does not affect the normal operation of the program, I would still like to remove it. I would be grateful for any help.
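
One possible way to hide the message, sketched under assumptions: the `WARNING:root:` prefix suggests it is emitted through the root logger of Python's standard `logging` module, so a filter on that logger can drop it. The filter class name below is hypothetical, and matching on the substring `_parameter_properties` is a guess based on the text quoted above. This only hides the message; it does not change how `SafeTruncatedNormal` behaves.

```python
import logging


class DropParameterPropertiesWarning(logging.Filter):
    """Drop log records whose message mentions `_parameter_properties`."""

    def filter(self, record):
        # Returning False discards the record; all other records pass through.
        return "_parameter_properties" not in record.getMessage()


# Assumption: the warning appears to be issued when the subclass is defined,
# so the filter should be installed before importing the module that
# defines SafeTruncatedNormal.
logging.getLogger().addFilter(DropParameterPropertiesWarning())
```

The alternative is the one the message itself recommends: define `_parameter_properties` on the subclass and store the subclass parameters in `self._parameters`, following the pattern quoted above.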
