
torch-ac's Introduction

PyTorch Actor-Critic deep reinforcement learning algorithms: A2C and PPO

The torch_ac package contains the PyTorch implementation of two Actor-Critic deep reinforcement learning algorithms:

  • Synchronous A3C (A2C)
  • Proximal Policy Optimization (PPO)

Note: An example of how to use this package is given in the rl-starter-files repository. More details below.

Features

  • Recurrent policies
  • Reward shaping
  • Handle observation spaces that are tensors or dicts of tensors
  • Handle discrete action spaces
  • Observation preprocessing
  • Multiprocessing
  • CUDA

Installation

pip3 install torch-ac

Note: If you want to modify the torch-ac algorithms, you will need to install a cloned version instead, i.e.:

git clone https://github.com/lcswillems/torch-ac.git
cd torch-ac
pip3 install -e .

Package components overview

A brief overview of the components of the package:

  • torch_ac.A2CAlgo and torch_ac.PPOAlgo classes for A2C and PPO algorithms
  • torch_ac.ACModel and torch_ac.RecurrentACModel abstract classes for non-recurrent and recurrent actor-critic models
  • torch_ac.DictList class for making dictionaries of lists list-indexable and hence batch-friendly

Package components details

The most important components of the package are detailed below.

torch_ac.A2CAlgo and torch_ac.PPOAlgo have 3 methods:

  • __init__ that may take, among other parameters:
    • an acmodel actor-critic model, i.e. an instance of a class inheriting from either torch_ac.ACModel or torch_ac.RecurrentACModel.
    • a preprocess_obss function that transforms a list of observations into a list-indexable object X (e.g. a PyTorch tensor). The default preprocess_obss function converts observations into a PyTorch tensor.
    • a reshape_reward function that takes as parameters an observation obs, the action taken action, the reward received reward and the terminal status done, and returns a new reward. By default, the reward is not reshaped.
    • a recurrence number that specifies over how many timesteps the gradient is backpropagated. This number is only taken into account if a recurrent model is used, and it must divide the num_frames_per_proc parameter and, for PPO, the batch_size parameter.
  • collect_experiences that runs the model in the environments and returns the collected experiences along with logs.
  • update_parameters that updates the model parameters from the given experiences and returns logs.
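The reward-shaping hook and the recurrence constraint can be sketched in plain Python. The function shaped_reward below is a hypothetical example with the (obs, action, reward, done) signature described above, and its 20x scale factor is an arbitrary illustrative choice; check_recurrence is a hypothetical helper that only restates the divisibility rule and is not part of torch_ac.

```python
def shaped_reward(obs, action, reward, done):
    # Hypothetical shaping rule: scale the environment reward by 20.
    # obs, action and done are part of the expected signature but unused here.
    return 20.0 * reward


def check_recurrence(num_frames, recurrence, batch_size=None):
    # Hypothetical helper: recurrence must divide the number of frames
    # collected per process and, for PPO, the batch size.
    if num_frames % recurrence != 0:
        raise ValueError("recurrence must divide the number of frames per process")
    if batch_size is not None and batch_size % recurrence != 0:
        raise ValueError("recurrence must divide batch_size")


# In practice the function would be handed to the algorithm, e.g.
# torch_ac.PPOAlgo(..., reshape_reward=shaped_reward)
print(shaped_reward(obs=None, action=0, reward=0.5, done=True))  # 10.0
check_recurrence(num_frames=128, recurrence=4, batch_size=256)   # no error
```

Keeping obs, action and done in the signature, even when unused, lets the same function be swapped for a richer rule (for example one that depends on done) without changing how it is passed in.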

torch_ac.ACModel has 2 abstract methods:

  • __init__ that takes as parameters an observation_space and an action_space.
  • forward that takes as parameter N preprocessed observations obs and returns a PyTorch distribution dist and a tensor of values value. The tensor of values must be of size N, not N x 1.

torch_ac.RecurrentACModel has 3 abstract methods:

  • __init__ that takes the same parameters as torch_ac.ACModel.
  • forward that takes the same parameters as torch_ac.ACModel, along with a tensor memory of N memories of size N x M, where M is the size of a memory. It returns the same outputs as torch_ac.ACModel plus a tensor memory of N new memories.
  • memory_size that returns the size M of a memory.

Note: The preprocess_obss function must return a list-indexable object (e.g. a PyTorch tensor). If your observations are dictionaries, your preprocess_obss function may first convert a list of dictionaries into a dictionary of lists and then make it list-indexable using the torch_ac.DictList class as follows:

>>> d = DictList({"a": [[1, 2], [3, 4]], "b": [[5], [6]]})
>>> d.a
[[1, 2], [3, 4]]
>>> d[0]
DictList({"a": [1, 2], "b": [5]})
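The two steps described in this note (list of dicts to dict of lists, then list-indexing) can be illustrated in plain Python. SimpleDictList and list_of_dicts_to_dict_of_lists are hypothetical names; SimpleDictList only mimics the attribute access and indexing behaviour shown above and is not torch_ac's implementation.

```python
def list_of_dicts_to_dict_of_lists(dicts):
    # [{"a": 1, "b": 2}, {"a": 3, "b": 4}] -> {"a": [1, 3], "b": [2, 4]}
    return {key: [d[key] for d in dicts] for key in dicts[0]}


class SimpleDictList(dict):
    """Illustrative stand-in for a list-indexable dict of lists."""

    def __getattr__(self, key):
        # d.a -> the list stored under "a"
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            raise AttributeError(key) from None

    def __getitem__(self, index):
        # d[i] -> a new SimpleDictList holding the i-th element of every list
        return SimpleDictList({key: value[index] for key, value in dict.items(self)})

    def __len__(self):
        # Batch size, i.e. the length of any stored list
        return len(next(iter(dict.values(self))))


obss = [{"a": [1, 2], "b": [5]}, {"a": [3, 4], "b": [6]}]
d = SimpleDictList(list_of_dicts_to_dict_of_lists(obss))
print(d.a)     # [[1, 2], [3, 4]]
print(d[0])    # {'a': [1, 2], 'b': [5]}
print(len(d))  # 2
```

Because indexing is redirected to the stored lists, string keys are read through attribute access (d.a) rather than d["a"], which is the trade-off that makes batch slicing look like list slicing.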

Note: If you use an RNN, you will need to set batch_first to True.

Examples

Examples of use of the package components are given in the rl-starter-files repository.

Example of use of torch_ac.A2CAlgo and torch_ac.PPOAlgo

...

# device (e.g. torch.device("cuda")) is passed right after acmodel
algo = torch_ac.PPOAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
                        args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                        args.optim_eps, args.clip_eps, args.epochs, args.batch_size, preprocess_obss)

...

exps, logs1 = algo.collect_experiences()
logs2 = algo.update_parameters(exps)

More details are given in the rl-starter-files repository.
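A training loop is typically built by repeating the two calls above until a frame budget is reached. The sketch below assumes that the collection logs contain a num_frames entry (as the rl-starter-files scripts use it); StubAlgo is a hypothetical stand-in with the same two methods so that the loop runs without any environment or model.

```python
class StubAlgo:
    """Hypothetical stand-in exposing collect_experiences / update_parameters."""

    def __init__(self, frames_per_update):
        self.frames_per_update = frames_per_update

    def collect_experiences(self):
        exps = ["experience"] * self.frames_per_update  # placeholder batch
        logs = {"num_frames": self.frames_per_update}
        return exps, logs

    def update_parameters(self, exps):
        return {"loss": 0.0}  # placeholder training statistics


def train(algo, total_frames):
    num_frames, num_updates = 0, 0
    while num_frames < total_frames:
        exps, logs1 = algo.collect_experiences()
        logs2 = algo.update_parameters(exps)
        logs = {**logs1, **logs2}  # merge collection and update logs
        num_frames += logs["num_frames"]
        num_updates += 1
    return num_frames, num_updates


print(train(StubAlgo(frames_per_update=2048), total_frames=10000))  # (10240, 5)
```

Since each update consumes a whole batch of collected frames, the loop can overshoot the budget by up to one batch, as the 10240 frames above show.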

Example of use of torch_ac.DictList

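def preprocess_obss(obss, device=None):
    # preprocess_images, preprocess_texts and vocab are helpers from rl-starter-files
    return torch_ac.DictList({
        "image": preprocess_images([obs["image"] for obs in obss], device=device),
        "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
    })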

More details are given in the rl-starter-files repository.

Example of implementation of torch_ac.RecurrentACModel

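import torch.nn as nn
import torch_ac

class ACModel(nn.Module, torch_ac.RecurrentACModel):
    ...

    @property
    def memory_size(self):
        ...  # return the size M of a memory

    def forward(self, obs, memory):
        ...

        return dist, value, memory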

More details are given in the rl-starter-files repository.

Examples of preprocess_obss functions

More details are given in the rl-starter-files repository.

torch-ac's Issues

Small bug in algos/base.py

Hi, I think line 127 and line 167 both use 'i' as the loop index, but they form a nested loop. This can cause incorrect indexing whenever 'done' contains a True value.
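The general hazard behind this report can be shown with a generic Python sketch (hypothetical functions, not the actual code of algos/base.py): reusing i in an inner loop overwrites the outer index for the rest of that iteration, although the outer for statement rebinds i at the start of the next iteration.

```python
def indices_with_shadowing(done):
    seen = []
    for i in range(len(done)):
        if done[i]:
            for i in range(3):  # inner loop reuses the name i
                pass
        seen.append(i)  # after a True entry, i holds the inner loop's last value
    return seen


def indices_without_shadowing(done):
    seen = []
    for i in range(len(done)):
        if done[i]:
            for j in range(3):  # distinct inner index leaves i untouched
                pass
        seen.append(i)
    return seen


done = [False, True, False, False]
print(indices_with_shadowing(done))     # [0, 2, 2, 3]
print(indices_without_shadowing(done))  # [0, 1, 2, 3]
```

Whether the real loop is affected therefore depends on whether i is read after the inner loop within the same outer iteration.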

Support for Off-Policy Actor-Critic Algorithms like ACER

Hi,

Is there any support for the off-policy counterpart of A2C (ACER algorithm) that can be made based on this repo?

This is a very useful repo that we use a lot, and it is nice that it is compatible with gym-minigrid tasks. However, there is no open-source implementation of ACER that is compatible with gym-minigrid yet.

Would it be possible to add ACER support to this repo?

I found a useful implementation that works for ALE tasks, but it does not support gym-minigrid:
https://github.com/belepi93/pytorch-acer

Support for Multi-agent?

Hi Lucas,

Thanks for sharing your project. It is a really nice platform for benchmarking. Is there any way to add a multi-agent option? Thanks!

ParallelEnv class yields non-correct rewards in a minigrid environment

I tried to use the ParallelEnv class to run parallel episodes in this minigrid environment: https://github.com/maximecb/gym-minigrid/blob/master/README.md (with MiniGrid-Empty-5x5-v0). The reward should be 1 - c*time_taken_to_reach_green (where c is a constant), but when I use ParallelEnv, the rewards do not follow this. In fact, I observe rewards that increase with time.
Example: say we have 10-step episodes. Normally we should observe rewards like this:
[0, 0, 0.95, 0, 0, 0.9, 0, 0, 0.85, 0]
(this is a list where the first element is the reward obtained at t=0, the second element is the reward at t=1, and so on.)
But with ParallelEnv(), I observe rewards like this:
[0, 0, 0.95, 0, 0, 0.95, 0, 0, 0.95, 0], or even increasing rewards like the following:
[0, 0, 0.85, 0, 0, 0.90, 0, 0, 0.95, 0]

I might be misunderstanding the purpose of the ParallelEnv class: my understanding was that it produces fully independent episodes and should not disrupt the original reward structure. It would be great if you could let me know how to fix this. Thank you!
