Reinforcement Learning Exploration Baselines (RLeXplore)

RLeXplore is a set of PyTorch implementations of intrinsic-reward-driven exploration approaches in reinforcement learning that can be plugged into arbitrary RL algorithms. In particular, RLeXplore is designed to be compatible with Stable-Baselines3, providing more stable exploration benchmarks.

Notice

This repo has been merged into a new project, https://github.com/RLE-Foundation/Hsuanwu, which provides more reasonable implementations.

Import the intrinsic reward modules with:

from hsuanwu.xplore.reward import ICM, RIDE, ...

Module List

| Module       | Remark                              | Repr. | Visual | Reference |
|--------------|-------------------------------------|-------|--------|-----------|
| PseudoCounts | Count-based exploration             | ✔️    | ✔️     | Never Give Up: Learning Directed Exploration Strategies |
| ICM          | Curiosity-driven exploration        | ✔️    | ✔️     | Curiosity-Driven Exploration by Self-Supervised Prediction |
| RND          | Count-based exploration             |       | ✔️     | Exploration by Random Network Distillation |
| GIRM         | Curiosity-driven exploration        | ✔️    | ✔️     | Intrinsic Reward Driven Imitation Learning via Generative Model |
| NGU          | Memory-based exploration            | ✔️    | ✔️     | Never Give Up: Learning Directed Exploration Strategies |
| RIDE         | Procedurally-generated environments | ✔️    | ✔️     | RIDE: Rewarding Impact-Driven Exploration for Procedurally-Generated Environments |
| RE3          | Entropy maximization                |       | ✔️     | State Entropy Maximization with Random Encoders for Efficient Exploration |
| RISE         | Entropy maximization                |       | ✔️     | Rényi State Entropy Maximization for Exploration Acceleration in Reinforcement Learning |
| REVD         | Divergence maximization             |       | ✔️     | Rewarding Episodic Visitation Discrepancy for Exploration in Reinforcement Learning |
  • 🐌: Developing.
  • Repr.: The method involves representation learning.
  • Visual: The method works well in visual RL.
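Several entries in the table are count-based: they reward states that have been visited rarely. As a minimal illustration of that idea only (it is not how PseudoCounts or RND work; both approximate visit counts because exact counts are infeasible in large or visual state spaces), a tabular bonus of the form beta / sqrt(N(s)) can be sketched as follows. The class `TabularCountBonus` and its `beta` parameter are hypothetical.

```python
import math
from collections import defaultdict


class TabularCountBonus:
    """Hypothetical helper: intrinsic reward r(s) = beta / sqrt(N(s))."""

    def __init__(self, beta: float = 1.0) -> None:
        self.beta = beta
        self.counts: defaultdict = defaultdict(int)  # visit count per hashable state

    def __call__(self, state) -> float:
        # Count the current visit first, so a first visit yields beta / sqrt(1).
        self.counts[state] += 1
        return self.beta / math.sqrt(self.counts[state])


bonus = TabularCountBonus(beta=1.0)
rewards = [bonus(s) for s in ["a", "a", "a", "a", "b"]]
print([round(r, 3) for r in rewards])  # [1.0, 0.707, 0.577, 0.5, 1.0]
```

Repeated visits to "a" earn a shrinking bonus, while the first visit to "b" earns the full bonus again.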

Example

Because intrinsic reward methods differ widely in how they compute rewards, Hsuanwu follows these conventions:

  1. The environments are assumed to be vectorized;
  2. The compute_irs function of each intrinsic reward module takes a mandatory argument, samples, which is a dict with the following entries:
    • obs (n_steps, n_envs, *obs_shape) <class 'torch.Tensor'>
    • actions (n_steps, n_envs, action_shape) <class 'torch.Tensor'>
    • rewards (n_steps, n_envs) <class 'torch.Tensor'>
    • next_obs (n_steps, n_envs, *obs_shape) <class 'torch.Tensor'>
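To make the shape conventions concrete, the sketch below builds a dummy samples dict with NumPy and checks that the leading (n_steps, n_envs) dimensions agree. The helpers `make_dummy_samples` and `check_samples` are hypothetical and not part of Hsuanwu; the sketch assumes a flat action vector of length `action_dim`, and `check_samples` requires all four keys even though some methods (such as RE3) need only obs. The actual modules expect torch.Tensor values, which can be obtained from these arrays with `torch.from_numpy`.

```python
import numpy as np


def make_dummy_samples(n_steps, n_envs, obs_shape, action_dim, seed=0):
    """Hypothetical helper: random data laid out as described above."""
    rng = np.random.default_rng(seed)
    return {
        "obs": rng.random((n_steps, n_envs, *obs_shape), dtype=np.float32),
        "actions": rng.random((n_steps, n_envs, action_dim), dtype=np.float32),
        "rewards": rng.random((n_steps, n_envs), dtype=np.float32),
        "next_obs": rng.random((n_steps, n_envs, *obs_shape), dtype=np.float32),
    }


def check_samples(samples):
    """Hypothetical helper: verify that all entries share (n_steps, n_envs)."""
    n_steps, n_envs = samples["obs"].shape[:2]
    if samples["next_obs"].shape != samples["obs"].shape:
        raise ValueError("next_obs must have the same shape as obs")
    if samples["actions"].shape[:2] != (n_steps, n_envs):
        raise ValueError("actions must start with (n_steps, n_envs)")
    if samples["rewards"].shape != (n_steps, n_envs):
        raise ValueError("rewards must have shape (n_steps, n_envs)")
    return n_steps, n_envs


samples = make_dummy_samples(n_steps=128, n_envs=7, obs_shape=(3, 16, 16), action_dim=1)
print(check_samples(samples))  # (128, 7)
```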

Take RE3 as an example: it computes the intrinsic reward of each state from the Euclidean distance between that state and its $k$-nearest neighbor within a mini-batch, so providing obs data alone is sufficient. The following code shows how to use RE3:
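The RE3 computation described above can be sketched in NumPy. This is an illustrative approximation, not Hsuanwu's implementation: a fixed random linear projection stands in for RE3's randomly initialized encoder, the $k$-nearest-neighbor search runs separately over the n_steps states of each environment, and the reward takes the log(1 + distance) form used in the RE3 paper. The function name `re3_style_rewards` and its `k`, `latent_dim`, and `seed` parameters are assumptions.

```python
import numpy as np


def re3_style_rewards(obs, k=3, latent_dim=16, seed=0):
    """Sketch of an RE3-style intrinsic reward.

    obs: array of shape (n_steps, n_envs, *obs_shape).
    Returns an array of shape (n_steps, n_envs).
    """
    n_steps, n_envs = obs.shape[:2]
    if not 0 < k < n_steps:
        raise ValueError("k must satisfy 0 < k < n_steps")
    flat = obs.reshape(n_steps, n_envs, -1).astype(np.float64)

    # Assumption: a fixed random linear projection stands in for RE3's
    # randomly initialized (and never trained) encoder.
    rng = np.random.default_rng(seed)
    proj = rng.standard_normal((flat.shape[-1], latent_dim)) / np.sqrt(flat.shape[-1])
    emb = flat @ proj  # (n_steps, n_envs, latent_dim)

    rewards = np.zeros((n_steps, n_envs))
    for e in range(n_envs):
        z = emb[:, e, :]  # embeddings of this environment's states
        # Pairwise Euclidean distances within the mini-batch.
        dists = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)
        # After sorting each row, column 0 is the zero self-distance,
        # so column k holds the distance to the k-th nearest neighbor.
        knn_dist = np.sort(dists, axis=1)[:, k]
        rewards[:, e] = np.log(knn_dist + 1.0)
    return rewards


obs = np.random.default_rng(1).random((128, 7, 3, 16, 16))
intrinsic = re3_style_rewards(obs)
print(intrinsic.shape)  # (128, 7)
```

States far from their neighbors in embedding space receive larger rewards; identical states receive zero.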

from hsuanwu.xplore.reward import RE3
from hsuanwu.env import make_dmc_env
import torch as th

if __name__ == '__main__':
    num_envs = 7
    num_steps = 128
    # create env
    env = make_dmc_env(env_id="cartpole_balance", num_envs=num_envs)
    print(env.observation_space, env.action_space)
    # create RE3 instance
    re3 = RE3(
        observation_space=env.observation_space,
        action_space=env.action_space
    )
    # compute intrinsic rewards
    obs = th.rand(size=(num_steps, num_envs, *env.observation_space.shape))
    intrinsic_rewards = re3.compute_irs(samples={'obs': obs})
    
    print(intrinsic_rewards.shape, type(intrinsic_rewards))
    print(intrinsic_rewards)

# Output (the final print of the reward tensor is omitted):
# {'shape': [9, 84, 84]} {'shape': [1], 'type': 'Box', 'range': [-1.0, 1.0]}
# torch.Size([128, 7]) <class 'torch.Tensor'>
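The returned tensor holds one intrinsic reward per (step, env) pair, matching the layout of the rewards entry in samples. A common way to use it (assumed here, not taken from Hsuanwu) is to add it to the extrinsic rewards with a coefficient that decays over training, for example beta_t = beta * (1 - kappa)^t. The helper `mix_rewards` and its default `beta` and `kappa` values are hypothetical; since it uses only elementwise arithmetic, it accepts either NumPy arrays or torch tensors of matching shape.

```python
import numpy as np


def mix_rewards(extrinsic, intrinsic, global_step, beta=0.05, kappa=2.5e-5):
    """Hypothetical helper: r = r_ext + beta_t * r_int, with beta_t = beta * (1 - kappa) ** global_step."""
    if tuple(extrinsic.shape) != tuple(intrinsic.shape):
        raise ValueError("extrinsic and intrinsic rewards must have the same shape")
    beta_t = beta * (1.0 - kappa) ** global_step  # shrinks toward 0 as training proceeds
    return extrinsic + beta_t * intrinsic


extrinsic = np.zeros((128, 7))
intrinsic = np.ones((128, 7))
print(mix_rewards(extrinsic, intrinsic, global_step=0)[0, 0])  # 0.05
```

Decaying the coefficient lets the agent explore early and rely increasingly on the task reward later.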
