Tackling Long-Horizon Tasks with Model-based Offline Reinforcement Learning

This repository contains the official implementation of Tackling Long-Horizon Tasks with Model-based Offline Reinforcement Learning by Kwanyoung Park and Youngwoon Lee.

If you use this code for your research, please consider citing our paper:

@article{park2024tackling,
  title={Tackling Long-Horizon Tasks with Model-based Offline Reinforcement Learning},
  author={Kwanyoung Park and Youngwoon Lee},
  journal={arXiv preprint arXiv:2407.00699},
  year={2024}
}

How to run the code

Install dependencies

conda create -n LEQ python=3.9
conda activate LEQ

pip install -r requirements.txt

# Install jax (https://github.com/google/jax#pip-installation-gpu-cuda)
# Quote the requirement so shells such as zsh do not treat [cuda] as a glob
pip install "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install glew & others
conda install -c conda-forge glew
conda install -c conda-forge mesalib
conda install -c menpo glfw3
export CPATH=$CONDA_PREFIX/include
pip install patchelf

# Install d4rl, then restore the pinned numpy/scipy versions it may have changed
pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
pip install numpy==1.23.0
pip install scipy==1.10.1

Also, see the JAX installation guide (https://github.com/google/jax#pip-installation-gpu-cuda) for other CUDA configurations.
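Because the numpy/scipy pins and the CUDA-enabled JAX build are easy to lose during installation, a quick check before training can save a failed run. The bash function below is a sketch, not part of the repository: check_leq_env is a hypothetical name, and it assumes that `python` (or whatever the PYTHON variable points to) is the interpreter of the activated LEQ environment. It only imports packages and prints what it finds.

```shell
# Post-install sanity check (a sketch, not part of the LEQ repository).
# check_leq_env is a hypothetical helper; PYTHON defaults to `python`,
# assumed to be the interpreter of the activated LEQ conda environment.
check_leq_env() {
  local py="${PYTHON:-python}"
  "$py" - <<'EOF'
import importlib

# Report each package's version, or why it could not be imported.
for name in ("jax", "d4rl", "numpy", "scipy"):
    try:
        mod = importlib.import_module(name)
    except Exception as exc:  # d4rl may raise non-ImportError errors (e.g. MuJoCo setup)
        print(f"{name}: MISSING ({type(exc).__name__})")
        continue
    print(f"{name}: {getattr(mod, '__version__', 'unknown')}")
    if name == "jax":
        # Typically "gpu" when the CUDA-enabled jaxlib is in use, "cpu" otherwise.
        print(f"jax backend: {mod.default_backend()}")
EOF
}
```

After running check_leq_env, numpy should report 1.23.0 and scipy 1.10.1 if the pins took effect; a MISSING line points at the install step to revisit.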

Pretrain world model

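To train the world model, we use the training script from OfflineRL-Kit.

For convenience, we provide run_dynamics.py, which trains the model with OfflineRL-Kit. Starting from the LEQ directory, the commands below clone OfflineRL-Kit next to this repository, install it, and copy run_dynamics.py and the d4rl_ext directory into it.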

cd ..
git clone https://github.com/yihaosun1124/OfflineRL-Kit
cd OfflineRL-Kit
python setup.py install
cp ../LEQ/dynamics/run_dynamics.py run_example/run_dynamics.py
cp -r ../LEQ/d4rl_ext .

You can now train the model with run_dynamics.py from the OfflineRL-Kit directory. For example:

python run_example/run_dynamics.py --task antmaze-medium-replay-v2 --seed 3
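If you need world models for several seeds of the same task, the command above can be wrapped in a loop. The bash function below is a hypothetical helper, not part of OfflineRL-Kit or LEQ; it uses only the --task and --seed flags shown above, and the DRY_RUN switch is an illustrative addition for previewing commands. Run it from the OfflineRL-Kit directory.

```shell
# Hypothetical helper (not part of OfflineRL-Kit or LEQ): pretrain one dynamics
# model per seed for a single task. Run it from the OfflineRL-Kit directory.
# Set DRY_RUN=1 to print the commands instead of executing them.
pretrain_dynamics() {
  local task="$1"
  shift
  local seed
  for seed in "$@"; do
    local cmd=(python run_example/run_dynamics.py --task "$task" --seed "$seed")
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "${cmd[*]}"
    else
      "${cmd[@]}" || return 1   # stop at the first failing run
    fi
  done
}
```

For example, pretrain_dynamics antmaze-medium-replay-v2 0 1 2 trains three models one after another; prefixing it with DRY_RUN=1 prints the three commands without running them. Building each command as an array keeps arguments intact and lets the dry run show exactly what would execute.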

Run training

LEQ

PYTHONPATH='.' python train/train_LEQ.py --env_name=walker2d-medium-replay-v2 --expectile 0.5

MOBILEQ (see the ablation study section of the paper for details)

PYTHONPATH='.' python train/train_MOBILEQ.py --env_name=Hopper-v3-medium --beta 1.0

MOBILE (JAX implementation of Sun et al.)

PYTHONPATH='.' python train/train_MOBILE.py --env_name=antmaze-large-play-v2 --beta 1.0
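The three training commands share one pattern: PYTHONPATH='.', a script under train/, and an --env_name flag, followed by method-specific flags. The bash wrapper below is a hypothetical convenience sketch, not part of the repository; it assumes it is run from the LEQ repository root (which PYTHONPATH='.' relies on) and passes any extra flags, such as --expectile or --beta, through unchanged. The DRY_RUN switch is an illustrative addition.

```shell
# Hypothetical wrapper (not part of the LEQ repository) around the training scripts above.
# Run it from the LEQ repository root, as PYTHONPATH='.' assumes.
# Usage: run_train ALGO ENV_NAME [extra flags...]; set DRY_RUN=1 to only print the command.
run_train() {
  if [ "$#" -lt 2 ]; then
    echo "usage: run_train ALGO ENV_NAME [extra flags...]" >&2
    return 2
  fi
  local algo="$1" env_name="$2"
  shift 2
  case "$algo" in
    LEQ|MOBILEQ|MOBILE) ;;   # the three scripts described in this README
    *) echo "unknown algorithm: $algo" >&2; return 2 ;;
  esac
  local cmd=(python "train/train_${algo}.py" "--env_name=${env_name}" "$@")
  if [ "${DRY_RUN:-0}" = "1" ]; then
    echo "PYTHONPATH=. ${cmd[*]}"
  else
    PYTHONPATH='.' "${cmd[@]}"
  fi
}
```

For example, run_train LEQ walker2d-medium-replay-v2 --expectile 0.5 runs the same command as the LEQ example above, and run_train MOBILE antmaze-large-play-v2 --beta 1.0 matches the MOBILE one.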

References

  • The implementation is based on IQL.
  • The MOBILE implementation is from OfflineRL-Kit.
