Rejax
Hardware-Accelerated Reinforcement Learning Algorithms in pure Jax!


Rejax is a library of RL algorithms implemented in pure Jax. It lets you accelerate your RL pipelines by applying jax.jit, jax.vmap, jax.pmap, or any other transformation to whole training algorithms. Use it to quickly search for hyperparameters, evaluate agents for multiple seeds in parallel, or run meta-evolution experiments on your GPUs and TPUs. If you're new to rejax and want to learn more about it, take the tour:

📸 Take a tour (Open In Colab)

[Animation: rejax demo]

๐Ÿ— Installing rejax

  1. Install via pip: pip install rejax
  2. Install from source: pip install git+https://github.com/keraJLi/rejax

⚡ Vectorize training for incredible speedups!

  • Use jax.jit on the whole train function to run training exclusively on your GPU!
  • Use jax.vmap and jax.pmap on the initial seed or hyperparameters to train a whole batch of agents in parallel!
import jax

from rejax.algos import get_agent

# Get train function and initialize config for training
train_fn, config_cls = get_agent("sac")
train_config = config_cls.create(env="CartPole-v1", learning_rate=0.001)

# Jit the training function
jitted_train_fn = jax.jit(train_fn)

# Vmap training function over 300 initial seeds
vmapped_train_fn = jax.vmap(jitted_train_fn, in_axes=(None, 0))

# Train 300 agents!
keys = jax.random.split(jax.random.PRNGKey(0), 300)
train_state, evaluation = vmapped_train_fn(train_config, keys)
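
You can vectorize over hyperparameters in the same way. The following is a minimal sketch, assuming config_cls.create accepts a traced learning_rate; the train_with_lr wrapper is illustrative and not part of rejax's API:

import jax.numpy as jnp

# Hypothetical wrapper: rebuild the config inside the mapped function so
# that the learning rate becomes a batched input
def train_with_lr(learning_rate, key):
    config = config_cls.create(env="CartPole-v1", learning_rate=learning_rate)
    return train_fn(config, key)

# Train one agent per learning rate, sharing a single seed
learning_rates = jnp.array([1e-4, 3e-4, 1e-3])
sweep_fn = jax.jit(jax.vmap(train_with_lr, in_axes=(0, None)))
train_states, evaluations = sweep_fn(learning_rates, jax.random.PRNGKey(0))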

Benchmark on an A100 80G and an Intel Xeon 4215R CPU. Note that the hyperparameters were set to the default values of CleanRL, including buffer sizes. Shrinking the buffers can yield additional speedups due to better caching, and enables training even more agents in parallel.
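
For example, a smaller replay buffer can be requested through the config. A hedged sketch, assuming the algorithm's config exposes a buffer_size field (the field name is an assumption; check the config class of your algorithm):

# Hypothetical: shrink the replay buffer to fit more agents on one GPU
small_buffer_config = config_cls.create(
    env="CartPole-v1",
    learning_rate=0.001,
    buffer_size=50_000,  # assumed field name, smaller than CleanRL's default
)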

[Figures: speedup over CleanRL on Hopper and Breakout]

🤖 Implemented algorithms

Algorithm | Link | Discrete | Continuous | Notes
----------|------|----------|------------|-------------------------------------------
PPO       | here | ✔        | ✔          |
SAC       | here | ✔        | ✔          | discrete version as in Christodoulou, 2019
DQN       | here | ✔        |            | incl. DDQN, Dueling DQN
DDPG      | here |          | ✔          |
TD3       | here |          | ✔          |

🛠 Easily extend and modify algorithms

The implementations focus on clarity! Modify the implemented algorithms easily by overriding isolated parts, such as the loss function, trajectory generation, or parameter updates. For example, turn DQN into DDQN by writing:

class DoubleDQN(DQN):
    @classmethod
    def update(cls, config, state, minibatch):
        # Calculate DDQN-specific targets
        targets = ddqn_targets(config, state, minibatch)

        # The loss function predicts Q-values and returns the
        # mean squared Bellman error (MSBE)
        def loss_fn(params):
            ...
            return jnp.mean((targets - q_values) ** 2)

        # Calculate gradients
        grads = jax.grad(loss_fn)(state.q_ts.params)

        # Update train state
        q_ts = state.q_ts.apply_gradients(grads=grads)
        state = state.replace(q_ts=q_ts)
        return state
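
Here, ddqn_targets is left to the reader. Below is a minimal sketch of what it could look like; the minibatch layout, the q_target_params field, config.gamma, and the apply_fn signature are all illustrative assumptions, not rejax's actual attributes:

import jax
import jax.numpy as jnp

def ddqn_targets(config, state, minibatch):
    obs, action, reward, done, next_obs = minibatch  # assumed transition layout

    # The online network selects the greedy next action ...
    next_q_online = state.q_ts.apply_fn(state.q_ts.params, next_obs)
    best_action = jnp.argmax(next_q_online, axis=-1)

    # ... while the target network evaluates it, decoupling action
    # selection from value estimation (the core idea of DDQN)
    next_q_target = state.q_ts.apply_fn(state.q_target_params, next_obs)
    next_q = jnp.take_along_axis(
        next_q_target, best_action[:, None], axis=-1
    ).squeeze(-1)

    # One-step TD target, masked at episode boundaries
    return reward + (1.0 - done) * config.gamma * next_q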

🔙 Flexible callbacks

Using callbacks, you can log to the console, to disk, to wandb, and much more, even when the whole train function is jitted! For example, run a jax.experimental.io_callback at regular intervals during training, or print the current policy's mean return:

def print_callback(config, state, rng):
    policy = make_act(config, state)         # Get current policy
    episode_returns = evaluate(policy, ...)  # Evaluate it
    jax.debug.print(                         # Print results
        "Step: {}. Mean return: {}",
        state.global_step,
        episode_returns.mean(),
    )
    return ()  # Must return a PyTree (an empty tuple works if there is nothing to report)

config = config.replace(eval_callback=print_callback)

Callbacks have the signature callback(config, train_state, rng) -> PyTree and are called every eval_freq training steps with the config and the current train state. The outputs of the callback are aggregated over training and returned by the train function. The default callback runs a number of episodes in the training environment and returns their lengths and episodic returns, so that the train function yields a training curve.

Importantly, this function is jit-compiled along with the rest of the algorithm. However, you can use one of Jax's callbacks, such as jax.experimental.io_callback, to implement model checkpointing, logging to wandb, and more, all while maintaining the advantages of a completely jittable training function.
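
As a self-contained sketch of this pattern: io_callback and the callback signature come from the docs above, while the mocked evaluation and the host-side logging function are illustrative assumptions (a real callback could call wandb.log or write a checkpoint on the host instead):

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def host_log(step, mean_return):
    # Runs on the host, outside the jitted computation
    print(f"step={int(step)}: mean return {float(mean_return):.2f}")

def logging_callback(config, state, rng):
    episode_returns = jnp.full((10,), 42.0)  # stand-in for a real evaluation
    # io_callback transfers the arguments to the host and calls host_log;
    # None indicates that the callback returns no value
    io_callback(host_log, None, state.global_step, episode_returns.mean())
    return ()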

💞 Alternatives for end-to-end GPU training

Libraries:

  • Brax: along with several environments, Brax implements PPO and SAC within its environment interface

Single-file implementations:

  • PureJaxRL implements PPO, recurrent PPO and DQN
  • Stoix features DQN, DDPG, TD3, SAC, PPO, as well as popular extensions and more

✍ Cite us!

@misc{rejax, 
  title={rejax}, 
  url={https://github.com/keraJLi/rejax}, 
  journal={keraJLi/rejax}, 
  author={Liesen, Jarek and Lu, Chris and Lange, Robert}, 
  year={2024}
} 
