Implementation of STORM: Efficient Stochastic Transformer based World Models for Reinforcement Learning

This repo contains an implementation of STORM.

Follow the Training and Evaluating Instructions to reproduce the main results presented in our paper. The Additional Useful Information section may help when debugging and observing intermediate results. To reproduce the speed metrics reported in the paper, see Reproducing Speed Metrics.

Training and Evaluating Instructions

  1. Install the necessary dependencies. Note that we conducted our experiments using Python 3.10.

    pip install -r requirements.txt

    Installing AutoROM.accept-rom-license may take several minutes.

  2. Train the agent.

    chmod +x train.sh
    ./train.sh

    The train.sh file sets the environment and the run name of a training process.

    env_name=MsPacman
    python -u train.py \
        -n "${env_name}-life_done-wm_2L512D8H-100k-seed1" \
        -seed 1 \
        -config_path "config_files/STORM.yaml" \
        -env_name "ALE/${env_name}-v5" \
        -trajectory_path "trajectory/${env_name}.pkl"
    • The env_name on the first line can be any Atari game, which can be found here.

    • The -n option sets the name of the TensorBoard logger and of the checkpoint folder. You can change it to your preference, but we recommend keeping the environment's name first. TensorBoard logs are written to the runs folder, and checkpoints to the ckpt folder.

    • The -seed parameter sets the random seed used during training. We evaluated our method using 5 seeds and report the mean return in Table 1.

    • The -config_path points to a YAML file that controls the model's hyperparameters. The configuration in config_files/STORM.yaml is the same as in our paper.

    • The -trajectory_path option is only used when UseDemonstration in the YAML file is set to True (it is False by default). This corresponds to the ablation studies in Section 5.3. We provide the pre-collected trajectories in the D_TRAJ.7z file; decompress it into a D_TRAJ folder before use.

  3. Evaluate the agent. The evaluation results will be presented in a CSV file located in the eval_result folder.

    chmod +x eval.sh
    ./eval.sh

    The eval.sh file sets the environment and the run name used when evaluating an agent.

    env_name=MsPacman
    python -u eval.py \
        -env_name "ALE/${env_name}-v5" \
        -run_name "${env_name}-life_done-wm_2L512D8H-100k-seed1" \
        -config_path "config_files/STORM.yaml" 

    The -run_name option corresponds to the -n option in train.sh and must be identical to the value used in the training script.
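Because the evaluation run name has to repeat the training run name exactly, it can help to generate both command lines from one place. The following is a minimal Python sketch, assuming the file layout used in train.sh and eval.sh and assuming seeds 1 to 5 (the README only states that 5 seeds were used, not which ones). The helpers run_name, train_command and eval_command are hypothetical and not part of the repository; the script only builds and prints the commands, it does not launch them.

```python
import shlex

# Values copied from train.sh / eval.sh; adjust as needed.
CONFIG_PATH = "config_files/STORM.yaml"
RUN_SUFFIX = "life_done-wm_2L512D8H-100k"


def run_name(env_name: str, seed: int) -> str:
    """Hypothetical helper: compose the run name with the environment name first."""
    return f"{env_name}-{RUN_SUFFIX}-seed{seed}"


def train_command(env_name: str, seed: int) -> list[str]:
    """Argument list equivalent to the python call in train.sh."""
    return [
        "python", "-u", "train.py",
        "-n", run_name(env_name, seed),
        "-seed", str(seed),
        "-config_path", CONFIG_PATH,
        "-env_name", f"ALE/{env_name}-v5",
        "-trajectory_path", f"trajectory/{env_name}.pkl",
    ]


def eval_command(env_name: str, seed: int) -> list[str]:
    """Argument list equivalent to the python call in eval.sh; -run_name reuses run_name()."""
    return [
        "python", "-u", "eval.py",
        "-env_name", f"ALE/{env_name}-v5",
        "-run_name", run_name(env_name, seed),
        "-config_path", CONFIG_PATH,
    ]


if __name__ == "__main__":
    # Assumption: seeds 1..5.
    for seed in range(1, 6):
        print(shlex.join(train_command("MsPacman", seed)))
        print(shlex.join(eval_command("MsPacman", seed)))
```

Each printed line can be pasted into a shell, or the argument lists can be passed directly to subprocess.run if you prefer to launch runs from Python. Deriving both -n and -run_name from the same function removes the risk of the two names drifting apart.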

Additional Useful Information

You can use TensorBoard to visualize the training curves and the imagination videos:

 chmod +x TensorBoard.sh
 ./TensorBoard.sh

Reproducing Speed Metrics

To reproduce the speed metrics mentioned in the paper, please consider the following:

  • Hardware requirements: an NVIDIA GeForce RTX 3090 with a high-frequency CPU; we used an 11th Gen Intel(R) Core(TM) i9-11900K in our experiments. Low-frequency CPUs may leave the GPU idle and slow down training. To make full use of a powerful GPU, you can train several agents on one device at the same time.
  • Software requirements: PyTorch>=2.0.0 is required.
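The PyTorch requirement can be checked before launching a long run. This is a minimal sketch assuming version strings of the usual form, such as 2.1.0+cu118. It compares only the numeric major.minor.patch part and ignores local and pre-release tags, so for strict comparisons the packaging library's Version class is the more robust choice. The function names are illustrative and not part of the repository.

```python
import re


def parse_version(version: str) -> tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '2.1.0+cu118'.

    Local suffixes (+cu118) and pre-release tags (a0, rc1) are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def meets_requirement(installed: str, minimum: str = "2.0.0") -> bool:
    """True if `installed` is at least `minimum` (numeric parts only)."""
    return parse_version(installed) >= parse_version(minimum)


if __name__ == "__main__":
    try:
        import torch  # only needed for this check
    except ImportError:
        print("PyTorch is not installed")
    else:
        status = "OK" if meets_requirement(torch.__version__) else "too old, need >= 2.0.0"
        print(torch.__version__, status)
```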

We also tested our code on other devices and identified some possible troubleshooting steps:

  • Our experiments used bfloat16 to accelerate training. To train on devices that do not support bfloat16, such as the NVIDIA V100, you need to change torch.bfloat16 to torch.float16 in both agents.py and sub_models/world_models.py. Additionally, modify the line attn = attn.masked_fill(mask == 0, -1e9) to attn = attn.masked_fill(mask == 0, -6e4) to prevent overflow.
  • On devices like the NVIDIA A100, using bfloat16 may slow down training. In this case, you can toggle the self.use_amp = True option (that is, set it to False) in both agents.py and sub_models/world_models.py.
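The reason for replacing -1e9 with -6e4 is the narrow range of float16: its largest finite magnitude is 65504, so -1e9 cannot be stored, while bfloat16 keeps float32's 8-bit exponent and holds -1e9 without overflow. The sketch below checks this with Python's standard struct module, whose 'e' format code packs IEEE 754 half-precision floats. The helper names are illustrative and do not appear in the repository.

```python
import struct

# Largest finite IEEE 754 half-precision value: (2 - 2**-10) * 2**15
FP16_MAX = (2 - 2**-10) * 2**15  # 65504.0


def fits_in_fp16(value: float) -> bool:
    """Return True if `value` can be stored as a finite float16."""
    try:
        struct.pack("<e", value)  # 'e' = IEEE 754 binary16
    except OverflowError:
        return False
    return True


def as_fp16(value: float) -> float:
    """Round-trip `value` through float16 to see the value actually stored."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


if __name__ == "__main__":
    for fill in (-1e9, -6e4):
        if fits_in_fp16(fill):
            print(f"{fill:g} -> stored as {as_fp16(fill)}")
        else:
            print(f"{fill:g} -> overflows float16 (max magnitude {FP16_MAX:g})")
```

In PyTorch, torch.finfo(torch.float16).max reports the same 65504. Filling with torch.finfo(attn.dtype).min is one common dtype-agnostic alternative to a hard-coded constant; this is a suggestion, not what the repository does.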

Code references

We've referenced several other projects during the development of this code:

BibTeX

@inproceedings{
zhang2023storm,
title={{STORM}: Efficient Stochastic Transformer based World Models for Reinforcement Learning},
author={Weipu Zhang and Gang Wang and Jian Sun and Yetian Yuan and Gao Huang},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=WxnrX42rnS}
}

Contributors

cmeo97, weipu-zhang
