Debiasing Meta-Gradient Reinforcement Learning by Learning the Outer Value Function

This repository contains the code for the paper "Debiasing Meta-Gradient Reinforcement Learning by Learning the Outer Value Function" (Clément Bonnet, Laurence Midgley, Alexandre Laterre), published at the 6th Workshop on Meta-Learning at NeurIPS 2022, New Orleans.

Abstract

Meta-gradient Reinforcement Learning (RL) allows agents to self-tune their hyper-parameters in an online fashion during training. In this paper, we identify a bias in the meta-gradient of current meta-gradient RL approaches. This bias comes from using the critic that is trained using the meta-learned discount factor for the advantage estimation in the outer objective which requires a different discount factor. Because the meta-learned discount factor is typically lower than the one used in the outer objective, the resulting bias can cause the meta-gradient to favor myopic policies. We propose a simple solution to this issue: we eliminate this bias by using an alternative, outer value function in the estimation of the outer loss. To obtain this outer value function we add a second head to the critic network and train it alongside the classic critic, using the outer loss discount factor. On an illustrative toy problem, we show that the bias can cause catastrophic failure of current meta-gradient RL approaches, and show that our proposed solution fixes it. We then apply our method to a more complex environment and demonstrate that fixing the meta-gradient bias can significantly improve performance.
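
The mismatch described in the abstract can be illustrated numerically. The following is a minimal sketch, not code from this repository: it assumes a hypothetical setting with a reward of 1 at every step and an infinite horizon, so the value under a discount factor gamma is 1 / (1 - gamma). The function name `n_step_advantage` and the discount values 0.9 and 0.99 are chosen for illustration only. Because there is a single behavior in this setting, the true advantage is zero.

```python
import numpy as np


def n_step_advantage(rewards, v_start, v_end, gamma):
    """n-step advantage: discounted rewards plus bootstrap value, minus baseline."""
    rewards = np.asarray(rewards, dtype=float)
    n = len(rewards)
    discounts = gamma ** np.arange(n)
    n_step_return = np.sum(discounts * rewards) + gamma**n * v_end
    return n_step_return - v_start


# Hypothetical setting (an assumption): reward 1 at every step, infinite horizon.
inner_gamma, outer_gamma = 0.9, 0.99
v_inner = 1.0 / (1.0 - inner_gamma)  # value under the (meta-learned) inner discount
v_outer = 1.0 / (1.0 - outer_gamma)  # value under the outer discount

rewards = np.ones(5)

# Outer advantage estimated with the critic trained under the inner discount.
adv_with_inner_critic = n_step_advantage(rewards, v_inner, v_inner, outer_gamma)
# Outer advantage estimated with a value function that matches the outer discount.
adv_with_outer_critic = n_step_advantage(rewards, v_outer, v_outer, outer_gamma)

print(adv_with_inner_critic, adv_with_outer_critic)
```

In this toy setting, the estimate that bootstraps from the inner-discount value is clearly nonzero (about 4.41), while the estimate that uses the outer-discount value is zero up to floating-point error.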

Experiments

We use the following abbreviations:

  • A2C: Advantage Actor Critic
  • MG: meta-gradient algorithm from [Xu et al., 2018]
  • BMG: bootstrapped meta-gradients from [Flennerhag et al., 2022]
  • MG outer-critic: the MG algorithm equipped with an outer-critic that estimates the outer value function used in the outer loss
  • BMG outer-critic: the BMG algorithm similarly equipped with an outer-critic
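
To make the outer-critic idea concrete, here is a minimal NumPy sketch of a critic with a shared torso and two value heads, each regressed towards a target built with its own discount factor. It is an illustration under stated assumptions, not the repository's implementation: the parameter names, layer sizes, discount values, and the use of one-step TD targets are all hypothetical.

```python
import numpy as np


def init_critic(rng, obs_dim, hidden_dim):
    """Shared torso with two linear value heads (names are hypothetical)."""
    return {
        "torso": rng.normal(scale=0.1, size=(obs_dim, hidden_dim)),
        "inner_head": rng.normal(scale=0.1, size=(hidden_dim,)),
        "outer_head": rng.normal(scale=0.1, size=(hidden_dim,)),
    }


def critic_forward(params, obs):
    """Return (inner value, outer value) for a batch of observations."""
    hidden = np.tanh(obs @ params["torso"])
    return hidden @ params["inner_head"], hidden @ params["outer_head"]


def td_target(rewards, next_values, gamma, done):
    """One-step TD target; the bootstrap term is dropped at episode end."""
    return rewards + gamma * (1.0 - done) * next_values


def critic_losses(params, obs, rewards, next_obs, done, inner_gamma, outer_gamma):
    """Squared TD error of each head, each using its own discount factor."""
    v_inner, v_outer = critic_forward(params, obs)
    next_v_inner, next_v_outer = critic_forward(params, next_obs)
    inner_target = td_target(rewards, next_v_inner, inner_gamma, done)
    outer_target = td_target(rewards, next_v_outer, outer_gamma, done)
    inner_loss = np.mean((inner_target - v_inner) ** 2)
    outer_loss = np.mean((outer_target - v_outer) ** 2)
    return inner_loss, outer_loss


# Usage on random data (all values below are placeholders).
rng = np.random.default_rng(0)
params = init_critic(rng, obs_dim=3, hidden_dim=8)
obs = rng.normal(size=(4, 3))
next_obs = rng.normal(size=(4, 3))
rewards = rng.normal(size=(4,))
done = np.zeros(4)
inner_loss, outer_loss = critic_losses(
    params, obs, rewards, next_obs, done, inner_gamma=0.9, outer_gamma=0.99
)
```

In a gradient-based framework the targets would be treated as constants (no gradient flows through them), and the two losses would be summed to train both heads together.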

Discounting Chain

[Figures: Discounting Chain Return, Discounting Chain Discount Factor]

Snake

[Figures: Snake Return, Snake Discount Factor]

Reproducibility

We provide a requirements.txt file with all the tagged packages needed to reproduce the experiments in the paper. The following commands set up a virtual environment and install these packages. Alternatively, you can use a Conda environment or a similar solution.

python -m venv venv
source venv/bin/activate
pip install -U pip setuptools wheel
pip install -r requirements.txt

Discounting Chain

The Discounting Chain environment is originally from bsuite but is imported from Gymnax in this paper to benefit from its JAX implementation.
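
As intuition for why the discount factor matters in a chain-style task with delayed rewards, the sketch below compares two hypothetical chains. The rewards and delays are illustrative assumptions, not the configuration used by bsuite, Gymnax, or the paper. Each chain pays a single reward after a fixed delay, so its discounted value is gamma ** delay * reward.

```python
def discounted_reward(reward, delay, gamma):
    """Value at time 0 of a single reward received after `delay` steps."""
    return gamma**delay * reward


def preferred_chain(chains, gamma):
    """Index of the chain with the highest discounted value."""
    values = [discounted_reward(reward, delay, gamma) for reward, delay in chains]
    return max(range(len(chains)), key=values.__getitem__)


# Hypothetical chains as (reward, delay): a quick small reward
# versus a slightly larger reward that arrives much later.
chains = [(1.0, 1), (1.1, 30)]

myopic_choice = preferred_chain(chains, gamma=0.9)
farsighted_choice = preferred_chain(chains, gamma=1.0)
print(myopic_choice, farsighted_choice)
```

With a discount factor of 0.9 the short chain looks better, whereas without discounting the long chain with the larger reward is preferred.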

To reproduce the experiments on the Discounting Chain environment, run the Jupyter notebook discounting_chain_train.ipynb. The other notebook, discounting_chain.ipynb, loads the data and plots the figures from the paper.

Snake

The Snake environment is provided by Jumanji in JAX.

To reproduce the experiments on the Snake environment, one can run the following commands.

  • Advantage Actor Critic (A2C)
python snake_train.py -m agent=a2c training.seed=1,2,3,4,5,6,7,8,9,10
  • Meta-Gradient (MG)
python snake_train.py -m agent=mgrl agent.outer_critic=false training.seed=1,2,3,4,5,6,7,8,9,10
python snake_train.py -m agent=mgrl agent.outer_critic=true training.seed=1,2,3,4,5,6,7,8,9,10
  • Bootstrapped Meta-Gradient (BMG)
python snake_train.py -m agent=bootstrap agent.outer_critic=false training.seed=1,2,3,4,5,6,7,8,9,10
python snake_train.py -m agent=bootstrap agent.outer_critic=true training.seed=1,2,3,4,5,6,7,8,9,10
  • Appendix
python snake_train.py -m agent=a2c agent.outer_critic=false agent.normalize_advantage=false,true agent.normalize_outer_advantage=false training.seed=1
python snake_train.py -m agent=mgrl agent.outer_critic=false,true agent.normalize_advantage=false,true agent.normalize_outer_advantage=false,true training.seed=1
python snake_train.py -m agent=bootstrap agent.outer_critic=false,true agent.normalize_advantage=false,true agent.normalize_outer_advantage=false,true training.seed=1
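
The `-m` flag with comma-separated values resembles Hydra's multirun syntax. Assuming that is what is used, and that each comma-separated list is swept as a Cartesian product, the following sketch estimates how many runs a command launches. The helper `expand_sweep` is hypothetical, written only for this illustration, and it does not call Hydra.

```python
from itertools import product


def expand_sweep(overrides):
    """Expand 'key=v1,v2' overrides into one dict per combination."""
    keys, value_lists = [], []
    for override in overrides:
        key, _, values = override.partition("=")
        keys.append(key)
        value_lists.append(values.split(","))
    return [dict(zip(keys, combo)) for combo in product(*value_lists)]


# Overrides copied from the MG appendix command above.
appendix_mgrl = [
    "agent=mgrl",
    "agent.outer_critic=false,true",
    "agent.normalize_advantage=false,true",
    "agent.normalize_outer_advantage=false,true",
    "training.seed=1",
]
runs = expand_sweep(appendix_mgrl)
print(len(runs))  # 8
```

Under the same assumption, each of the main commands with `training.seed=1,2,3,4,5,6,7,8,9,10` launches 10 runs, one per seed.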

The default logger is "terminal". To save the data, enable the Neptune logger, which is already implemented, by replacing "terminal" with "neptune" in snake_train.py.

For the paper, the data from these runs was collected and uploaded to snake/data/. The snake.ipynb notebook loads this data and makes the plots from the paper.

Citation

For attribution in academic contexts, please use the following citation.

@misc{bonnet2022debiasing,
  title = {Debiasing Meta-Gradient Reinforcement Learning by Learning the Outer Value Function},
  author = {Bonnet, Clément and Midgley, Laurence and Laterre, Alexandre},
  doi = {10.48550/ARXIV.2211.10550},
  url = {https://arxiv.org/abs/2211.10550},
  year = {2022},
  booktitle = {Sixth Workshop on Meta-Learning at the Conference on Neural Information Processing Systems},
}
