
simpledreamer's Introduction

EasyDreamer: A Simplified Version of the Dreamer Algorithm with PyTorch

Introduction

In this repository, we implement a simplified version of the Dreamer algorithm, which is explained in detail in the paper Dream to Control: Learning Behaviors by Latent Imagination. Dreamer's main goal is to train a model that lets agents perform well in their environments with high sample efficiency. Our version is implemented in PyTorch, which keeps the code simple and makes the model accessible to researchers and practitioners who are already familiar with the framework. With this implementation, they can gain a deeper understanding of how the algorithm works and test their own ideas more efficiently, contributing to further research in this field.

We have also included a re-implementation of Plan2Explore, a model-based exploration method introduced in the paper Planning to Explore via Self-Supervised World Models. Plan2Explore improves the world model's generalization without using any task-relevant information, relying instead on a self-supervised exploration objective. Our PyTorch implementation of Plan2Explore is available in this repository.

Differences from other implementations

Our implementation of Dreamer differs from others in several ways:

  • We separate the recurrent model from the other models, to make the deterministic path of the RSSM easier to follow.
  • We align the naming conventions in our implementation with those in the paper, and the modules are trained following the same pseudocode as the original Dreamer paper.
  • We remove overshooting, which was crucial in Dreamer-v1 and earlier model-based approaches but is no longer mentioned in Dreamer-v2 and v3, and is even omitted from the official implementations.
  • We use a single-step lambda-value calculation, which improves readability at the expense of performance.


Installation

To install the required dependencies, run the following command:

pip install -r requirements.txt

Run

To run the training process, use the following command:

Dreamer

python main.py --config dmc-walker-walk

Plan2Explore

python main.py --config p2e-dmc-walker-walk

Architecture

Dreamer

┌── dreamer
│   ├── algorithms
│   │   ├── dreamer.py : The Dreamer algorithm, including its loss functions and training loop
│   │   └── plan2explore.py : The Plan2Explore algorithm, including its loss functions and training loop
│   ├── configs
│   │   └ : Hyperparameters for the training process and the training-environment setup
│   ├── envs
│   │   ├── envs.py : Defines the environments used by the Dreamer algorithm
│   │   └── wrappers.py : Modifies environment observations to make them more suitable for training
│   ├── modules
│   │   ├── actor.py : A linear network that generates actions
│   │   │     └ input : deterministic and stochastic (state)
│   │   │     └ output : action
│   │   ├── critic.py : A linear network that estimates values
│   │   │     └ input : deterministic and stochastic
│   │   │     └ output : value
│   │   ├── decoder.py : A transposed-convolution network that reconstructs the image
│   │   │     └ input : deterministic and stochastic
│   │   │     └ output : reconstructed image
│   │   ├── encoder.py : A convolutional network that embeds the observation
│   │   │     └ input : image
│   │   │     └ output : embedded observation
│   │   ├── model.py : Contains the implementation of the models
│   │   │   ├ RSSM : the Recurrent State-Space Model
│   │   │   │  ├ RecurrentModel : A recurrent neural network that generates the deterministic state
│   │   │   │  │    └ input : stochastic, deterministic, and action
│   │   │   │  │    └ output : deterministic
│   │   │   │  ├ TransitionModel : A linear network that generates the stochastic state; we call it the prior
│   │   │   │  │    └ input : deterministic
│   │   │   │  │    └ output : stochastic (prior)
│   │   │   │  └ RepresentationModel : A linear network that generates the stochastic state; we call it the posterior
│   │   │   │       └ input : embedded observation and deterministic
│   │   │   │       └ output : stochastic (posterior)
│   │   │   ├ RewardModel : A linear network that predicts the reward
│   │   │   │    └ input : deterministic and stochastic
│   │   │   │    └ output : reward
│   │   │   └ ContinueModel : A linear network that predicts the continue flag (not done)
│   │   │        └ input : deterministic and stochastic
│   │   │        └ output : continue flag
│   │   └── one_step_model.py : A linear network that predicts the embedded observation (used by Plan2Explore)
│   │         └ input : deterministic, stochastic, and action
│   │         └ output : embedded observation
│   └── utils
│       ├── buffer.py : The replay buffer used to store and sample transitions during training
│       └── utils.py : Other utility functions
└── main.py : Reads the configuration file, sets up the environment, and starts the training process
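
To make the RSSM wiring above concrete, here is a minimal PyTorch sketch of its components. The layer sizes, activations, and min-std constant are illustrative assumptions, not the repository's actual hyperparameters:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import Normal

    class RecurrentModel(nn.Module):
        """(stochastic, action, previous deterministic) -> new deterministic."""
        def __init__(self, stoch_size=30, action_size=6, deter_size=200):
            super().__init__()
            self.fc = nn.Linear(stoch_size + action_size, deter_size)
            self.gru = nn.GRUCell(deter_size, deter_size)

        def forward(self, stoch, action, deter):
            x = F.elu(self.fc(torch.cat([stoch, action], dim=-1)))
            return self.gru(x, deter)

    class TransitionModel(nn.Module):
        """deterministic -> prior distribution over the stochastic state."""
        def __init__(self, deter_size=200, stoch_size=30, hidden_size=200):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(deter_size, hidden_size), nn.ELU(),
                nn.Linear(hidden_size, 2 * stoch_size),
            )

        def forward(self, deter):
            mean, std = torch.chunk(self.net(deter), 2, dim=-1)
            return Normal(mean, F.softplus(std) + 0.1)  # min_std keeps std > 0

    # The RepresentationModel (the posterior) has the same shape as the
    # TransitionModel, but its input is the concatenation of the embedded
    # observation and the deterministic state.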

Todo

  • Discrete action-space environment performance check
  • Code-coverage test
  • Dreamer-v2
  • Dreamer-v3

Performance

Dreamer

Task                       20-EMA
ball-in-cup-catch           936.9
walker-stand                972.8
quadruped-walk              584.7
cheetah-run                 694.0
cartpole-balance            831.2
cartpole-swingup-sparse     219.3
finger-turn_easy            805.1
cartpole-balance-sparse     541.6
hopper-hop                  250.7
walker-run                  284.6
reacher-hard                162.7
reacher-easy                911.4
acrobot-swingup              91.8
finger-spin                 543.5
cartpole-swingup            607.8
walker-walk                 871.3

All reported results were obtained by running the experiments three times with different random seeds. Evaluation was performed after each interaction with the environment, and the reported performance metric is the 20-EMA (exponential moving average) of the cumulative reward in a single episode.


References

  • Hafner, D., Lillicrap, T., Ba, J., & Norouzi, M. Dream to Control: Learning Behaviors by Latent Imagination. ICLR 2020.
  • Sekar, R., Rybkin, O., Daniilidis, K., Abbeel, P., Hafner, D., & Pathak, D. Planning to Explore via Self-Supervised World Models. ICML 2020.

simpledreamer's People

Contributors

anthony0727, aoberai, seolhokim, shougakusei


simpledreamer's Issues

Tanh function applied to the output of representation and transition models

RepresentationModel

posterior_dist = create_normal_dist(x, min_std=self.config.min_std)

TransitionModel

prior_dist = create_normal_dist(x, min_std=self.config.min_std)

create_normal_dist

if std == None:

Hi, I have another question about your implementation: why do you apply the tanh function to the output of the representation and transition models?
In the forward methods of both the representation and transition models, std is not passed as an argument, so the tanh function is applied to the chunked input x.

In the paper I have not found any mention of applying the tanh function to the output of the representation and transition models, so I would like to ask whether you observed improvements with this choice or used it for other reasons.

Thanks in advance for your help.
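
For concreteness, the pattern the issue describes looks roughly like the following hypothetical reconstruction (the repository's actual helper may differ): when std is not passed, x is chunked into a mean and a std, and tanh is applied to the mean before the distribution is built.

    import torch
    import torch.nn.functional as F
    from torch.distributions import Normal

    def create_normal_dist(x, std=None, min_std=0.1):
        if std is None:
            mean, std = torch.chunk(x, 2, dim=-1)
            mean = torch.tanh(mean)           # the tanh the issue asks about
            std = F.softplus(std) + min_std   # keep std positive, bounded below
        else:
            mean = x
        return Normal(mean, std)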

Torch no grad

def environment_interaction(self, env, num_interaction_episodes, train=True):

I would like to ask why the torch.no_grad decorator is not used in this method.
I tried running your agent, and without the torch.no_grad decorator the execution raised a "CUDA out of memory" runtime error.

Thank you in advance
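
For context: rollout code is usually wrapped in torch.no_grad (or torch.inference_mode) precisely to avoid this, since otherwise every policy forward pass keeps its autograd graph alive and GPU memory grows until CUDA runs out. A minimal sketch with a hypothetical policy/env API (classic gym step signature), not the repository's actual method:

    import torch

    @torch.no_grad()  # no autograd graph is recorded inside this function
    def environment_interaction(policy, env, num_interaction_episodes):
        for _ in range(num_interaction_episodes):
            observation, done = env.reset(), False
            while not done:
                # forward pass through the policy without tracking gradients
                action = policy(torch.as_tensor(observation, dtype=torch.float32))
                observation, reward, done, info = env.step(action.numpy())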

Lambda values computation

def compute_lambda_values(rewards, values, continues, horizon_length, device, lambda_):

I would like to ask for an explanation of how you calculate the lambda values: I do not understand why you remove the first reward and the first done flag.

Thank you in advance.
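
For reference, here is a sketch of the textbook Dreamer lambda-return recursion (not necessarily this repository's exact slicing): V^λ_t = r_t + c_t * ((1 - λ) * v_{t+1} + λ * V^λ_{t+1}), bootstrapped with V^λ_H = v_H, where c_t is the continue flag with the discount γ already folded in. Because each target pairs the reward at step t with the value at step t+1, one entry at a sequence boundary goes unused, which is one reason implementations slice off a first or last reward/done:

    import torch

    def lambda_returns(rewards, values, continues, lambda_=0.95):
        # rewards, values, continues: tensors of shape (horizon,).
        # `continues` is assumed to already include the discount factor,
        # i.e. continues[t] = gamma * (1 - done[t]).
        last = values[-1]  # bootstrap from the value of the final state
        targets = []
        for t in reversed(range(len(rewards) - 1)):
            last = rewards[t] + continues[t] * (
                (1 - lambda_) * values[t + 1] + lambda_ * last
            )
            targets.append(last)
        return torch.stack(targets[::-1])  # lambda targets for steps 0 .. H-2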

AttrDict is old and appears to be broken on python3.10

The error "cannot import name 'Mapping' from 'collections'" is caused by the Python version (see https://stackoverflow.com/a/70557518). It comes from the attrdict dependency, which is old and has been archived: https://github.com/bcj/AttrDict. Because attrdict is no longer updated to fix this, the repository's default example no longer runs on Python 3.10+. I didn't investigate too deeply, so it might be an issue on my side.
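
A common workaround, based on the linked Stack Overflow answer, is to re-expose the ABCs that moved to collections.abc before attrdict is imported; pinning Python below 3.10 or replacing attrdict with a small dict subclass are cleaner long-term fixes:

    # Run this before `from attrdict import AttrDict`. On Python 3.10+ the
    # abstract base classes moved to collections.abc, which breaks attrdict.
    import collections
    import collections.abc

    for _name in ("Mapping", "MutableMapping", "Sequence"):
        if not hasattr(collections, _name):
            setattr(collections, _name, getattr(collections.abc, _name))

    from attrdict import AttrDict  # now imports cleanly on Python 3.10+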

Sample index bug in replay buffer

    def sample(self, batch_size, chunk_size):
        """
        (batch_size, chunk_size, input_size)
        """
        last_filled_index = self.buffer_index - chunk_size + 1
        assert self.full or (
            last_filled_index > batch_size
        ), "too short dataset or too long chunk_size"
        sample_index = np.random.randint(
            0, self.capacity if self.full else last_filled_index, batch_size
        ).reshape(-1, 1)
        chunk_length = np.arange(chunk_size).reshape(1, -1)

 =>   sample_index = (sample_index + chunk_length) % self.capacity

I suspect there is an error in the line marked by =>: when sample_index is close to buffer_index, the part of a chunk that extends past it lands on uninitialized entries of the buffer (which later read as zeros). So the correct version would be

      sample_index = (sample_index + chunk_length) % self.buffer_index

[Refactor] pixel_normalization to gym.Wrapper

In cases other than image observations, we might not want our features to be normalized by pixel_normalization.

Following common practice in RL code, extract it into a gym.Wrapper and preprocess observations before inserting them into the buffer, as sketched below.
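
A minimal sketch of the proposed wrapper (hypothetical class name, classic gym API):

    import gym
    import numpy as np

    class PixelNormalization(gym.ObservationWrapper):
        """Scale uint8 pixel observations to floats in [-0.5, 0.5)."""
        def observation(self, observation):
            return observation.astype(np.float32) / 255.0 - 0.5

Only image-based environments would then be wrapped, e.g. env = PixelNormalization(env), leaving low-dimensional observations untouched.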
