rlax's Introduction

RLax

RLax (pronounced "relax") is a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents. Full documentation can be found at rlax.readthedocs.io.

Installation

You can install the latest released version of RLax from PyPI via:

pip install rlax

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/rlax.git

All RLax code may then be just-in-time compiled for different hardware (e.g. CPU, GPU, TPU) using jax.jit.

To run the examples in examples/, you will also need to clone the repository and install the additional requirements: optax, haiku, and bsuite.

Content

The operations and functions provided are not complete algorithms, but implementations of reinforcement learning specific mathematical operations that are needed when building fully-functional agents capable of learning:

  • Values, including both state and action-values;
  • Values for non-linear generalizations of the Bellman equations;
  • Return Distributions, aka distributional value functions;
  • General Value Functions, for cumulants other than the main reward;
  • Policies, via policy-gradients in both continuous and discrete action spaces.

The library supports both on-policy and off-policy learning (i.e. learning from data sampled from a policy different from the agent's policy).

See file-level and function-level doc-strings for the documentation of these functions and for references to the papers that introduced and/or used them.

Usage

See examples/ for examples of using some of the functions in RLax to implement a few simple reinforcement learning agents, and to demonstrate learning on BSuite's version of the Catch environment (a common unit test for agent development in the reinforcement learning literature).

Other examples of JAX reinforcement learning agents using rlax can be found in bsuite.

Background

Reinforcement learning studies the problem of a learning system (the agent), which must learn to interact with the universe it is embedded in (the environment).

Agent and environment interact on discrete steps. On each step the agent selects an action, and is provided in return a (partial) snapshot of the state of the environment (the observation), and a scalar feedback signal (the reward).

The behaviour of the agent is characterized by a probability distribution over actions, conditioned on past observations of the environment (the policy). The agent seeks a policy that, from any given step, maximises the discounted cumulative reward that will be collected from that point onwards (the return).

Often the agent's policy or the environment dynamics are themselves stochastic. In this case the return is a random variable, and the optimal agent's policy is typically more precisely specified as a policy that maximises the expectation of the return (the value), under the agent's and environment's stochasticity.
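As a concrete reading of these definitions, the following plain-Python sketch (not part of RLax; `discounted_return` and the constant discount `gamma` are illustrative assumptions) computes the return of a single finite episode. The value is then the expectation of this quantity over the randomness of the policy and the environment.

```python
def discounted_return(rewards, gamma):
    """Discounted return r_1 + gamma * r_2 + gamma**2 * r_3 + ... of one finite episode."""
    g = 0.0
    # Fold backwards: G_t = r_{t+1} + gamma * G_{t+1}, with G = 0 after the last reward.
    for r in reversed(rewards):
        g = r + gamma * g
    return g


# 1.0 + 0.5 * 0.0 + 0.25 * 4.0 = 2.0
print(discounted_return([1.0, 0.0, 4.0], gamma=0.5))  # prints 2.0
```

Folding backwards needs only one multiply-add per step, which is also how multi-step targets are usually computed over a trajectory.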

Reinforcement Learning Algorithms

There are three prototypical families of reinforcement learning algorithms:

  1. those that estimate the value of states and actions, and infer a policy by inspection (e.g. by selecting the action with highest estimated value);
  2. those that learn a model of the environment (capable of predicting the observations and rewards) and infer a policy via planning;
  3. those that parameterize a policy that can be directly executed.

In any case, policies, values or models are just functions. In deep reinforcement learning such functions are represented by a neural network. In this setting, it is common to formulate reinforcement learning updates as differentiable pseudo-loss functions (analogously to (un-)supervised learning). Under automatic differentiation, the original update rule is recovered.
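To illustrate the last sentence without JAX, the sketch below uses plain Python with a finite-difference gradient standing in for automatic differentiation (all names are illustrative, not RLax functions). It takes one gradient step on a squared TD(0) pseudo-loss whose target is treated as a constant, and checks that this reproduces the familiar update V <- V + lr * (target - V).

```python
def td_pseudo_loss(v_tm1, target):
    # Squared TD error; `target` is passed as a plain number, so it acts like a
    # stop-gradient: only v_tm1 is differentiated.
    return 0.5 * (target - v_tm1) ** 2


def numerical_grad(f, x, eps=1e-6):
    # Central finite difference, standing in for automatic differentiation.
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


v_tm1, r_t, discount_t, v_t = 1.0, 0.5, 0.9, 2.0
learning_rate = 0.1
target = r_t + discount_t * v_t  # bootstrapped TD(0) target, held fixed

grad = numerical_grad(lambda v: td_pseudo_loss(v, target), v_tm1)
v_from_loss = v_tm1 - learning_rate * grad              # one gradient step on the pseudo-loss
v_from_rule = v_tm1 + learning_rate * (target - v_tm1)  # the classic TD(0) update
print(v_from_loss, v_from_rule)  # both approximately 1.13
```

If the target were also differentiated, the gradient would pick up an extra term through v_t and would no longer match the TD(0) rule; this is why the target is held fixed.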

Note, however, that these updates are only valid if the input data is sampled in the correct manner. For example, a policy gradient loss is only valid if the input trajectory is an unbiased sample from the current policy, i.e. the data are on-policy. The library cannot check or enforce such constraints. However, the functions' doc-strings link to the papers that describe how each operation is used.
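The following plain-Python sketch illustrates why the sampling distribution matters. It assumes a single-step (bandit-like) setting with a softmax policy over `logits` and a fixed advantage per action; the helper names are hypothetical and this is not RLax's policy-gradient loss. Averaging the gradient of the pseudo-loss `-log pi(a) * adv` over actions drawn from the current policy recovers the true gradient of the expected advantage, while averaging over a different (here uniform) behaviour distribution does not.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def pg_pseudo_loss_grad(logits, a, adv):
    """Gradient w.r.t. logits of the pseudo-loss -log pi(a) * adv, with adv held constant."""
    pi = softmax(logits)
    # d(-log softmax(logits)[a]) / d logits = pi - onehot(a)
    return [adv * (p - (1.0 if i == a else 0.0)) for i, p in enumerate(pi)]


def expected_pseudo_loss_grad(logits, advantages, sampling_probs):
    """Average the per-sample gradient over actions drawn from `sampling_probs`."""
    total = [0.0] * len(logits)
    for a, (w, adv) in enumerate(zip(sampling_probs, advantages)):
        g = pg_pseudo_loss_grad(logits, a, adv)
        total = [t + w * gi for t, gi in zip(total, g)]
    return total


def reference_grad(logits, advantages, eps=1e-6):
    """Finite-difference gradient of -J, where J(logits) = sum_a pi(a) * advantage(a)."""
    def neg_j(ls):
        return -sum(p * adv for p, adv in zip(softmax(ls), advantages))

    grads = []
    for i in range(len(logits)):
        up, down = list(logits), list(logits)
        up[i] += eps
        down[i] -= eps
        grads.append((neg_j(up) - neg_j(down)) / (2.0 * eps))
    return grads


logits = [0.2, -0.5, 1.0]
advantages = [1.0, -2.0, 0.5]
on_policy = expected_pseudo_loss_grad(logits, advantages, softmax(logits))
off_policy = expected_pseudo_loss_grad(logits, advantages, [1 / 3, 1 / 3, 1 / 3])
print(reference_grad(logits, advantages))
print(on_policy)   # matches the reference
print(off_policy)  # does not
```

An off-policy learner would need a correction, such as importance weighting, to recover the on-policy expectation.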

Naming Conventions and Developer Guidelines

We define functions and operations for agents interacting with a single stream of experience. The JAX construct vmap can be used to apply these same functions to batches (e.g. to support replay and parallel data generation).

Many functions consider policies, actions, rewards, and values in consecutive timesteps in order to compute their outputs. In this case the suffixes _t and _tm1 are often used to clarify on which step each input was generated, e.g.:

  • q_tm1: the action value in the source state of a transition.
  • a_tm1: the action that was selected in the source state.
  • r_t: the resulting rewards collected in the destination state.
  • discount_t: the discount associated with a transition.
  • q_t: the action values in the destination state.
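To show how these suffixes read in practice, here is a plain-Python sketch of a one-step Q-learning TD error for a single transition, written with exactly these argument names. It is an illustrative re-implementation, not RLax's code; a batched version would be obtained by mapping it over a leading batch axis, e.g. with jax.vmap.

```python
def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    """One-step Q-learning TD error for a single transition.

    q_tm1: action values in the source state; a_tm1: action selected there;
    r_t, discount_t: reward and discount of the transition;
    q_t: action values in the destination state.
    """
    target_tm1 = r_t + discount_t * max(q_t)  # bootstrap from the greedy value
    return target_tm1 - q_tm1[a_tm1]


# Ordinary transition: 1.0 + 0.9 * 3.0 - 2.0, i.e. approximately 1.7.
print(q_learning_td_error([2.0, 0.5], 0, 1.0, 0.9, [1.0, 3.0]))
# Terminal transition: discount_t = 0 ignores q_t, giving 1.0 - 0.5 = 0.5.
print(q_learning_td_error([2.0, 0.5], 1, 1.0, 0.0, [1.0, 3.0]))
```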

Extensive testing is provided for each function. All tests should also verify the output of rlax functions when compiled to XLA using jax.jit and when performing batch operations using jax.vmap.

Citing RLax

This repository is part of the DeepMind JAX Ecosystem. To cite RLax, please use the following citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}

rlax's People

Contributors

akssri, chris-chris, dbudden, dependabot[bot], github30, hamzamerzic, hawkinsp, hbq1, joaogui1, jqdm, katebaumli, kristianholsheimer, mbrukman, mtthss, rchen152, suryabhupa, tomhennigan

rlax's Issues

can not find setup.py for pip install

pip install git+git://github.com/deepmind/rlax.git

pip install is not working.
I think you missed pushing setup.py

$ pip install git+git://github.com/deepmind/rlax.git
Collecting git+git://github.com/deepmind/rlax.git
  Cloning git://github.com/deepmind/rlax.git to /private/var/folders/kv/x78mzwkd087567npnl1l3b200000gn/T/pip-req-build-nspn1ikj
  Running command git clone -q git://github.com/deepmind/rlax.git /private/var/folders/kv/x78mzwkd087567npnl1l3b200000gn/T/pip-req-build-nspn1ikj
    ERROR: Command errored out with exit status 1:
     command: /Users/chris/anaconda3/bin/python -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/private/var/folders/kv/x78mzwkd087567npnl1l3b200000gn/T/pip-req-build-nspn1ikj/setup.py'"'"'; __file__='"'"'/private/var/folders/kv/x78mzwkd087567npnl1l3b200000gn/T/pip-req-build-nspn1ikj/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' egg_info --egg-base /private/var/folders/kv/x78mzwkd087567npnl1l3b200000gn/T/pip-req-build-nspn1ikj/pip-egg-info
         cwd: /private/var/folders/kv/x78mzwkd087567npnl1l3b200000gn/T/pip-req-build-nspn1ikj/
    Complete output (5 lines):
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "/Users/chris/anaconda3/lib/python3.7/tokenize.py", line 447, in open
        buffer = _builtin_open(filename, 'rb')
    FileNotFoundError: [Errno 2] No such file or directory: '/private/var/folders/kv/x78mzwkd087567npnl1l3b200000gn/T/pip-req-build-nspn1ikj/setup.py'
    ----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

Does discount = 0 mean "terminal" state by design?

I am building the code on which to base my experiments using this library. I noticed that in many rlax functions several episodes can seamlessly be treated as a single episode by appropriately setting zeros in the discount arrays. For example, I can concatenate two episodes together:

# Using notation s_tm1, a_tm1, r_t, discount_t, s_t
# Let s_0, s_1, s_2, s_3, s_4, s_5 be the sequence of states of two episodes,
# where s_2 is terminal state of first episode and s_3 is the first state of the second episode
# The transition s_2 -> r_3 -> discount_3 -> s_3 is therefore "faked" with a discount = 0
values    = [v_1, v_2, v_3, v_4, v_5] 
rewards   = [r_1, r_2, 0.0, r_4, r_5]
discounts = [0.9, 0.9, 0.0, 0.9, 0.9]  # discount corresponding to s_2 -> s_3 is 0 since s_2 is terminal

targets =  n_step_bootstrapped_returns(rewards, discounts, values, n=2)

I checked several functions and this seems to be the case. Is this the case in all functions by design? If yes, will it continue to remain like this? I'm asking for confirmation as I might base some design choices on this very nice property of rlax.

Thanks for the beautiful and useful library!
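The mechanism described in this issue can be checked with a small plain-Python sketch of n-step bootstrapped returns. The function `n_step_returns` is hypothetical, not RLax's `n_step_bootstrapped_returns`; it assumes index i holds the reward, discount and destination-state value of the transition into state i + 1, and that returns are truncated at the end of the sequence by bootstrapping from the last value. With that layout, a zero discount on the faked transition stops both the rewards and the bootstrap values of the second episode from flowing into the first episode's targets.

```python
def n_step_returns(rewards, discounts, values, n):
    """n-step bootstrapped returns for every index of one (possibly concatenated) sequence.

    Index i describes the transition into state i + 1: rewards[i] and discounts[i]
    belong to that transition and values[i] is the value of its destination state.
    Near the end of the sequence the horizon is truncated and the return bootstraps
    from the last available value.
    """
    T = len(rewards)
    targets = []
    for i in range(T):
        last = min(i + n, T) - 1  # index of the last transition included
        g = values[last]          # bootstrap value
        for k in range(last, i - 1, -1):
            g = rewards[k] + discounts[k] * g  # a zero discount cuts off everything after k
        targets.append(g)
    return targets


# Layout from the issue: index 2 is the "faked" transition between the two episodes.
rewards = [1.0, 2.0, 0.0, 4.0, 5.0]
discounts = [0.9, 0.9, 0.0, 0.9, 0.9]
values = [10.0, 20.0, 30.0, 40.0, 50.0]
print(n_step_returns(rewards, discounts, values, n=2))  # approximately [19.0, 2.0, 0.0, 49.0, 50.0]
```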

Implementation for Squashed Gaussian

Hi! I have questions regarding the implementation of the squashed Gaussian distribution in distributions.py.

  1. In https://github.com/deepmind/rlax/blob/b7d1a012f888d1744245732a2bcf15f38bb7511e/rlax/_src/distributions.py#L319, should it be "-=" instead of the current "+=", since we have to take the log of the inverse of the determinant?

  2. I am having trouble understanding a particular function in the implementation of the squashed gaussian in:
    https://github.com/deepmind/rlax/blob/b7d1a012f888d1744245732a2bcf15f38bb7511e/rlax/_src/distributions.py#L295

In particular, I don't understand why the mean (mu) and standard deviation (sigma) of the original Gaussian random variable (pre-squash) have to be "activated". The mean (mu) seems to go through a "tanh" transformation and sigma goes through a really odd-looking transformation on lines 295 and 299. Why do we need sigma_min (default -4) and sigma_max (default 0)? What is the meaning of these transformations on mu and sigma?

  3. My understanding of the implementation is that instead of "squashing" a Gaussian random variable (call it X) into (-1, 1), we are squashing X into the range given by the environment's action_specs. Therefore, the transformation/squashing function is no longer just Y = tanh(X) as in the original SAC paper but rather Y = g(X) = scale * (tanh(X) + 1) + min_vals, and this is confirmed by the function "transform": https://github.com/deepmind/rlax/blob/b7d1a012f888d1744245732a2bcf15f38bb7511e/rlax/_src/distributions.py#L302. I went through the math to derive the pdf and log-pdf of this new transformed/squashed random variable Y, and my result is nothing like the code. Did I miss something in the code? I suspect that this is related to question 2. Any comments and reference pointers would be much appreciated!
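The change of variables discussed in this issue can be checked numerically. Below is a minimal plain-Python sketch for the one-dimensional case, assuming the transform quoted above, Y = scale * (tanh(X) + 1) + min_vals, with scale = (max - min) / 2 (an assumption about how scale is defined). The helper names are hypothetical and the code is not RLax's implementation. Note the sign convention: the log-absolute-determinant of the forward map's Jacobian is subtracted from the base log-density, which is the same as adding the log-determinant of the inverse map's Jacobian.

```python
import math


def squashed_gaussian_logpdf(y, mu, sigma, min_val, max_val):
    """Log-density of Y = scale * (tanh(X) + 1) + min_val, with X ~ N(mu, sigma**2).

    Assumes scale = (max_val - min_val) / 2, so Y ranges over (min_val, max_val).
    """
    scale = (max_val - min_val) / 2.0
    u = (y - min_val) / scale - 1.0  # u = tanh(x), in (-1, 1)
    x = math.atanh(u)                # recover the pre-squash sample
    log_p_x = (-0.5 * ((x - mu) / sigma) ** 2
               - math.log(sigma) - 0.5 * math.log(2.0 * math.pi))
    # Change of variables: log p_Y(y) = log p_X(x) - log|dy/dx|,
    # where dy/dx = scale * (1 - tanh(x)**2) = scale * (1 - u**2).
    log_abs_det_jacobian = math.log(scale) + math.log1p(-u * u)
    return log_p_x - log_abs_det_jacobian


def squashed_gaussian_cdf(y, mu, sigma, min_val, max_val):
    """CDF of Y; valid because the squashing map is strictly increasing."""
    scale = (max_val - min_val) / 2.0
    x = math.atanh((y - min_val) / scale - 1.0)
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


# The density should match a finite-difference derivative of the CDF.
mu, sigma, lo, hi = 0.3, 0.5, -2.0, 1.0
y, eps = 0.1, 1e-5
density = math.exp(squashed_gaussian_logpdf(y, mu, sigma, lo, hi))
fd_density = (squashed_gaussian_cdf(y + eps, mu, sigma, lo, hi)
              - squashed_gaussian_cdf(y - eps, mu, sigma, lo, hi)) / (2.0 * eps)
print(density, fd_density)  # should agree to several decimal places
```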

Documentation and Examples

So, are you guys planning on releasing the documentation and adding more examples? It's kind of hard to understand what the library has to offer and how to use it right now.

Writing a MPO example (help I'm confused)

I'm trying to write an example for MPO (for a categorical action space).
However, I'm confused.

Mainly I'm confused about the kl_constraints arg to mpo_loss.

kl_constraints = [(rlax.categorical_kl_divergence(???, ???), lagrange_penalty)]

I don't understand what the two arguments to the KL divergence would be.
(Also, I don't understand why it's a list. How can there be more than one KL divergence?)

AFAIK, this KL constraint is to be used for the M-step.
So I should be doing something like:

$$ J(\theta) = \dots + \mathrm{KL}\big(\pi(a \mid s, \theta_i) \,\|\, \pi(a \mid s, \theta)\big) $$

However, this equation also doesn't make sense to me.
Aren't we evaluating the gradient of $J$ at $\theta_i$, so the KL term would be 0?

What am I missing...? (something important it seems.)
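As background for this question, here is a plain-Python sketch of a KL divergence between two categorical policies given as logits, together with a finite-difference gradient. `categorical_kl` and `kl_grad_wrt_q` are hypothetical helpers, not RLax's `categorical_kl_divergence` (whose argument order should be checked in its doc-string). The sketch shows that KL(pi(.|s, theta_i) || pi(.|s, theta)) and its gradient with respect to theta's logits both vanish when theta = theta_i; the term only becomes active once the parameters move away from theta_i.

```python
import math


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def categorical_kl(p_logits, q_logits):
    """KL(p || q) for two categorical distributions parameterised by logits."""
    log_p, log_q = log_softmax(p_logits), log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def kl_grad_wrt_q(p_logits, q_logits, eps=1e-6):
    """Finite-difference gradient of KL(p || q) with respect to q's logits."""
    grads = []
    for i in range(len(q_logits)):
        up, down = list(q_logits), list(q_logits)
        up[i] += eps
        down[i] -= eps
        grads.append((categorical_kl(p_logits, up) - categorical_kl(p_logits, down)) / (2.0 * eps))
    return grads


old_logits = [0.5, -1.0, 2.0]  # plays the role of pi(a|s, theta_i)
print(categorical_kl(old_logits, old_logits))  # 0.0 when the two policies coincide
print(kl_grad_wrt_q(old_logits, old_logits))   # approximately [0, 0, 0]
new_logits = [0.0, -1.0, 2.5]  # parameters after some update
print(categorical_kl(old_logits, new_logits))  # positive once the policies differ
```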

PopArt example bug

Hi,

It is mentioned in the comments here that new popart states should be used to normalize/denormalize, but in the code old states are being used.

Thanks
Kinal
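For reference, here is a plain-Python sketch of the output-preserving step in the usual PopArt formulation, assuming a linear value head whose normalized output is mapped back with running statistics mu and sigma. The helpers are hypothetical, not RLax's PopArt functions. After the statistics are updated, the last layer is rescaled so that unnormalized outputs do not change, and targets are then normalized with the updated statistics.

```python
def unnormalized_value(w, b, mu, sigma, h):
    """Value head output: sigma * (w . h + b) + mu."""
    return sigma * (sum(wi * hi for wi, hi in zip(w, h)) + b) + mu


def preserve_outputs(w, b, old_mu, old_sigma, new_mu, new_sigma):
    """Rescale the last linear layer so unnormalized outputs survive a stats update."""
    w_new = [wi * old_sigma / new_sigma for wi in w]
    b_new = (old_sigma * b + old_mu - new_mu) / new_sigma
    return w_new, b_new


w, b, h = [0.5, -1.0], 0.2, [1.0, 3.0]
old_mu, old_sigma = 0.0, 1.0
new_mu, new_sigma = 2.0, 4.0  # statistics after observing a new batch of targets

before = unnormalized_value(w, b, old_mu, old_sigma, h)
w_new, b_new = preserve_outputs(w, b, old_mu, old_sigma, new_mu, new_sigma)
after = unnormalized_value(w_new, b_new, new_mu, new_sigma, h)
print(before, after)  # equal up to float rounding

# The regression target is then normalized with the *new* statistics.
target = 5.0
normalized_target = (target - new_mu) / new_sigma
```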

requirements and examples folder should not be packaged.

Hi, given the current setup.py, the requirements and example folder are packaged in the wheel. I don't think this is intended and this has caused some issues for downstream usage. I created a PR with a fix, could you kindly take a look at #132 ?

Install error with jaxlib version 0.1.47

rlax gives the following error when I try to install:

Collecting jaxlib>=0.1.37 (from rlax==0.0.1)
Could not find a version that satisfies the requirement jaxlib>=0.1.37 (from rlax==0.0.1) (from versions: 0.1, 0.1.1, 0.1.2, 0.1.3, 0.1.4, 0.1.5, 0.1.6, 0.1.7, 0.1.8, 0.1.9, 0.1.11, 0.1.12, 0.1.13, 0.1.14, 0.1.15, 0.1.16, 0.1.17, 0.1.18, 0.1.19, 0.1.20, 0.1.21, 0.1.22, 0.1.23)
No matching distribution found for jaxlib>=0.1.37 (from rlax==0.0.1)

Release new version to loosen jax version constraints

Hi,

Thanks for the very useful library.

I had some issues using rlax with the most recent versions of other libraries in the DM JAX ecosystem. Specifically, the current release pins jax<=0.2.21, which is incompatible with the latest version of dm-haiku (v0.0.6). Is it possible to release a new version that includes commit 58b3672?

Best,
YL

Support for Bandits?

Thanks for the library! I'm a huge fan of it. I know that RL isn't the same as bandits, but would DM ever consider integrating various bandit algorithms into the rlax library?

Questions re: RLax Value Learning ?

Hi! I have several questions/requests regarding value learning https://github.com/deepmind/rlax/blob/master/rlax/_src/value_learning.py

  1. If I want to use the _quantile_regression_loss without the Huber aspect, does setting huber_param equal to 0 accomplish this? That's my understanding, but I'd like to check :)

  2. I'm interested in exploring expectile regression-naive DQN and expectile regression DQN, but code for these two related algorithms doesn't seem to exist. Is that correct? If code does exist, could you point me in the right direction?

  3. If functions for expectile regression indeed do not exist, what would be the most straightforward way to implement them? If I just want expectile regression-naive, I'm thinking I would need to do the following:

a. Copy _quantile_regression_loss() to create _expectile_regression_loss(), replacing the quantile loss with expectile loss
b. Copy quantile_q_learning() to create expectile_q_learning, replacing the _quantile_regression_loss() call with a _expectile_regression_loss() call

Is this correct? If so, would you be open to PRs?

  4. Expectile regression is a little trickier, due to its imputation strategy. Are you planning on implementing and releasing that? If not, how would you recommend implementing it?
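The two losses contrasted in this issue differ only in how the asymmetric weight |tau - 1{u < 0}| is applied. Below is a minimal scalar sketch, assuming u = target - prediction and the QR-DQN definition of the quantile Huber loss (Huber loss divided by kappa). `quantile_huber_loss`, `expectile_loss` and `minimizer` are hypothetical helpers, not RLax's `_quantile_regression_loss`, so whether `huber_param=0` behaves like `kappa = 0` here should be confirmed against the RLax source. The sketch also checks the defining property of each loss: at tau = 0.5 the quantile loss is minimized by the median and the expectile loss by the mean.

```python
def quantile_huber_loss(u, tau, kappa):
    """QR-DQN style quantile Huber loss for one error u = target - prediction.

    With kappa = 0 this is the plain quantile ("pinball") loss |tau - 1{u < 0}| * |u|.
    """
    weight = abs(tau - (1.0 if u < 0 else 0.0))
    if kappa == 0:
        return weight * abs(u)
    huber = 0.5 * u * u if abs(u) <= kappa else kappa * (abs(u) - 0.5 * kappa)
    return weight * huber / kappa


def expectile_loss(u, tau):
    """Expectile loss: the same asymmetric weight applied to a squared error."""
    weight = abs(tau - (1.0 if u < 0 else 0.0))
    return weight * u * u


def minimizer(loss, samples, tau, grid):
    """Grid search for the statistic that minimizes the summed loss over samples."""
    return min(grid, key=lambda m: sum(loss(x - m, tau) for x in samples))


samples = [1.0, 2.0, 6.0]
grid = [i / 1000 for i in range(7001)]
# tau = 0.5: the quantile loss recovers the median, the expectile loss the mean.
print(minimizer(lambda u, t: quantile_huber_loss(u, t, 0.0), samples, 0.5, grid))  # 2.0
print(minimizer(expectile_loss, samples, 0.5, grid))  # 3.0
```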

Relax numpy version requirement

1e2e4bc introduced version qualifier numpy < 1.23 due to failing tests with jax, but I believe this has been resolved with recent jax versions and we no longer need < 1.23.

I could write a trivial PR that fixes requirements.txt, but let me leave it to be handled by the rlax team because the commit does not specify at which jax version tests are failing, hopefully the CL can update jax and numpy version requirements.

/cc @katebaumli

Stop action gradient in policy gradient loss

The current implementation of policy_gradient_loss is:

log_pi_a_t = distributions.softmax().logprob(a_t, logits_t)
adv_t = jax.lax.select(use_stop_gradient, jax.lax.stop_gradient(adv_t), adv_t)
loss_per_timestep = -log_pi_a_t * adv_t

It's good that the gradients are already stopped around the advantages, but they should also be stopped around the actions to ensure an unbiased gradient estimator.

This is important when the actions are sampled as part of the training graph (MPO-style algos, imagination training with world models) rather than coming from the replay buffer, and the actor distribution implements a gradient for sample() (e.g. gaussian, or straight-through categoricals).

Question of `logprob_fn` in the `squashed_gaussian`?

Hello, I have two questions about the logprob_fn in squashed_gaussian.

  1. For a squashed multivariate Gaussian distribution, it seems a factor of 0.5 is missing in half_logdet = jnp.sum(jnp.log(sigma), axis=-1):

https://github.com/deepmind/rlax/blob/87392e93a1da2fda240034930e801daf54590218/rlax/_src/distributions.py#L308

  2. When computing the log_det_jacobian, we need to use inv_transform(sample, action_spec) to convert a back to u:

https://github.com/deepmind/rlax/blob/87392e93a1da2fda240034930e801daf54590218/rlax/_src/distributions.py#L271
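For reference, the standard identities behind both points, assuming `sigma` holds per-dimension standard deviations (not variances) and a plain componentwise squash a = tanh(u) without rescaling (rescaling to the action spec only adds a constant per-dimension log-scale term). Under that assumption, the sum of log standard deviations is already half the log-determinant, and the Jacobian term is a function of the pre-squash u:

```latex
\log \mathcal{N}\big(u;\, \mu,\, \operatorname{diag}(\sigma^2)\big)
  = -\frac{1}{2}\sum_{i=1}^{d}\Big(\frac{u_i-\mu_i}{\sigma_i}\Big)^2
    - \frac{1}{2}\log\det\operatorname{diag}(\sigma^2)
    - \frac{d}{2}\log 2\pi,
\qquad
\frac{1}{2}\log\det\operatorname{diag}(\sigma^2)
  = \frac{1}{2}\sum_{i=1}^{d}\log\sigma_i^2
  = \sum_{i=1}^{d}\log\sigma_i,

\log\pi(a)
  = \log \mathcal{N}\big(u;\, \mu,\, \operatorname{diag}(\sigma^2)\big)
    - \sum_{i=1}^{d}\log\big(1-\tanh^2(u_i)\big),
\qquad u = \operatorname{atanh}(a).
```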

tree_multimap deprecated in favor of tree_map

I get this error when trying to import rlax, I think because tree_multimap has been deprecated:

ImportError: cannot import name 'tree_multimap' from 'jax.tree_util' (/home/rohanmehta/anaconda3/envs/py39/lib/python3.9/site-packages/jax/tree_util.py)

Make documentation download-able as PDF?

Is there a reason it's not? I'd like to download it to my tablet reader (a low-powered e-ink device) and be able to reference it from there. The web browsing experience on some of those devices (even for local files) leaves a lot to be desired.

vtrace uses `lax.scan`?

Hi,

I'm wondering whether there is any fundamental reason the implementation of vtrace uses a Python for loop rather than lax.scan.
The reason I ask is that vtrace currently appears to take significantly longer to compile than a scan version, while the loop version does not seem to run any faster.
Is backward pass memory usage the issue here?

My scan based implementation: https://colab.research.google.com/drive/1lnxLNSse90MG8WuUu251KUMCzT0k2H4c#scrollTo=x3oxiZ9ZI5Vs

Benchmark numbers (in seconds):

rlax compile: 22.384824752807617
rlax: 38.763017416000366
scan compile: 0.5877189636230469
scan: 38.86183452606201
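The recursion both versions compute is short. Below is a plain-Python sketch of V-trace value targets following the backward recursion from the IMPALA paper, v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})); the function and argument names are illustrative and the code is not RLax's implementation. Each iteration only reads the accumulator from the next step, which is exactly the carry a `jax.lax.scan(..., reverse=True)` would thread through. One plausible contributor to the compile-time gap is that a Python loop inside a jitted function is unrolled at trace time, so the traced program grows with the sequence length, whereas `scan` traces the loop body once.

```python
def vtrace_targets(v_tm1, v_t, r_t, discount_t, rho_tm1,
                   lambda_=1.0, clip_rho_threshold=1.0, clip_c_threshold=1.0):
    """V-trace value targets for one trajectory, via a backward Python loop.

    v_tm1[i], v_t[i]: values of the source and destination states of transition i
    (so v_t[i] == v_tm1[i + 1] inside a trajectory); rho_tm1[i] = pi(a|s) / mu(a|s).
    """
    T = len(r_t)
    clipped_rhos = [min(clip_rho_threshold, r) for r in rho_tm1]
    cs = [lambda_ * min(clip_c_threshold, r) for r in rho_tm1]
    # Importance-weighted one-step TD errors.
    errors = [clipped_rhos[i] * (r_t[i] + discount_t[i] * v_t[i] - v_tm1[i]) for i in range(T)]
    targets = [0.0] * T
    acc = 0.0  # v_s - V(x_s), propagated backwards through time
    for i in reversed(range(T)):
        acc = errors[i] + discount_t[i] * cs[i] * acc
        targets[i] = v_tm1[i] + acc
    return targets


v_tm1 = [1.0, 2.0, 3.0]
v_t = [2.0, 3.0, 4.0]
r_t = [0.5, -1.0, 2.0]
discount_t = [0.9, 0.9, 0.9]
# On-policy (all ratios 1): targets reduce to n-step returns bootstrapped from v_t[-1].
print(vtrace_targets(v_tm1, v_t, r_t, discount_t, [1.0, 1.0, 1.0]))  # approximately [4.136, 4.04, 5.6]
```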
