

GTG

Official code for "Guided Trajectory Generation with Diffusion Models for Offline Model-based Optimization".

Environment Setup

To install the dependencies, run the following commands:

# Create conda environment
conda create -n gtg python=3.8 -y
conda activate gtg

# Mujoco Installation
pip install Cython==0.29.36 numpy==1.22.0 mujoco_py==2.1.2.14
# Compile mujoco_py (the build is triggered by the first import)
python -c "import mujoco_py"

# Torch Installation
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117

# Design-Bench Installation
pip install design-bench==2.0.12
pip install robel==0.1.2 morphing_agents==1.5.1 transforms3d --no-dependencies
pip install botorch==0.6.4 gpytorch==1.6.0

# Decision Diffuser Installation
pip install jaynes==0.8.11 ml_logger==0.8.69
pip install gym==0.13.1 params_proto==2.9.6 scikit-image==0.17.2 scikit-video==1.1.11 scikit-learn==0.23.1 typed-argument-parser einops wandb
pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl

# Download the Design-Bench offline datasets (design_bench_data.zip) from: https://drive.google.com/file/d/11nAyb7_tmlGd0ri5aOK5YP3ZMIYFOmvS/view?usp=drive_link
unzip design_bench_data.zip
rm -rf design_bench_data.zip
mv -v design_bench_data <CONDA_PATH>/envs/gtg/lib/python3.8/site-packages
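Before running any experiments, it can help to confirm that the pinned versions above actually ended up in the environment (a later `pip install` can replace an earlier pin while resolving dependencies) and that `design_bench_data` landed in the right directory. The script below is a minimal sketch, not part of the repository: `PINNED` is a hand-copied subset of the versions above, and `installed_version`, `find_mismatches`, and `site_packages_dir` are hypothetical helper names. It assumes that the `purelib` path reported by `sysconfig` for the active `gtg` interpreter is the `<CONDA_PATH>/envs/gtg/lib/python3.8/site-packages` directory used in the `mv` command, which is the case for a standard conda environment.

```python
import sysconfig
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Callable, Dict, Optional

# Subset of the versions pinned in the installation commands above.
# Keys are distribution names as recorded in package metadata (assumption:
# e.g. "mujoco-py" rather than the import name "mujoco_py").
PINNED = {
    "Cython": "0.29.36",
    "numpy": "1.22.0",
    "mujoco-py": "2.1.2.14",
    "torch": "1.13.1+cu117",
    "design-bench": "2.0.12",
    "botorch": "0.6.4",
    "gpytorch": "1.6.0",
    "gym": "0.13.1",
}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of `dist`, or None if it is not installed."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def find_mismatches(
    pinned: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map every pinned distribution whose installed version differs to the
    version actually found (None means the distribution is missing)."""
    mismatches: Dict[str, Optional[str]] = {}
    for dist, expected in pinned.items():
        found = lookup(dist)
        if found != expected:
            mismatches[dist] = found
    return mismatches


def site_packages_dir() -> Path:
    """site-packages directory of the running interpreter."""
    return Path(sysconfig.get_paths()["purelib"])


if __name__ == "__main__":
    for dist, found in find_mismatches(PINNED).items():
        print(f"{dist}: expected {PINNED[dist]}, found {found}")
    data_dir = site_packages_dir() / "design_bench_data"
    print(f"design_bench_data present: {data_dir.is_dir()} ({data_dir})")
```

Run it with the `gtg` environment activated; an empty mismatch list plus `design_bench_data present: True` suggests the setup steps completed as intended.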

Code references

Our implementation is based on "Is Conditional Generative Modeling all you need for Decision Making?" (https://github.com/anuragajay/decision-diffuser).

Main Experiments

You can run the following commands to train and evaluate our method on Design-Bench tasks.

  • Constructing Trajectories: To construct trajectories, run the following command.
python construct_trajectories.py --task <task>
  • Training Models: To train the models, run the following command.
python train.py --task <task> --horizon <horizon> --seed <seed>
  • Evaluation: To sample candidates and evaluate them, run the following command.
python evaluate.py --task <task> --horizon <horizon> --ctx_len <ctx_len> --alpha <alpha> --seed <seed>
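The three commands are meant to run in this order for a given task and seed (construct trajectories, train, then sample and evaluate). A small driver can chain them for one configuration and stop at the first failing stage. This is a minimal sketch, not part of the repository: `main_pipeline` and `run` are hypothetical helpers, it assumes the scripts are invoked from the repository root, and it passes only the flags shown above.

```python
import shlex
import subprocess
import sys
from typing import List


def main_pipeline(task: str, horizon: int, ctx_len: int, alpha: float, seed: int) -> List[List[str]]:
    """Build the three "Main Experiments" commands in the order they run."""
    py = sys.executable  # interpreter of the active (gtg) environment
    return [
        [py, "construct_trajectories.py", "--task", task],
        [py, "train.py", "--task", task, "--horizon", str(horizon), "--seed", str(seed)],
        [py, "evaluate.py", "--task", task, "--horizon", str(horizon),
         "--ctx_len", str(ctx_len), "--alpha", str(alpha), "--seed", str(seed)],
    ]


def run(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it as well when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError on a non-zero exit code,
            # so a later stage never starts after an earlier one failed.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Placeholder values for illustration only, not recommended settings.
    run(main_pipeline(task="<task>", horizon=64, ctx_len=4, alpha=0.5, seed=0))
```

With `dry_run=True` (the default) the driver only prints the commands, which is a cheap way to check the arguments before launching a long training run.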

Additional Experiments

You can run the following commands to train and evaluate our method in practical settings of the Design-Bench tasks, namely a sparse setting and a noisy setting.

  • Sparse Setting
python construct_trajectories.py --task <task> --frac <frac>

python train.py --task <task> --horizon <horizon> --seed <seed> --frac <frac>

python evaluate.py --task <task> --horizon <horizon> --ctx_len <ctx_len> --alpha <alpha> --seed <seed> --frac <frac>
  • Noisy Setting
python construct_trajectories.py --task <task> --sigma <sigma>

python train.py --task <task> --horizon <horizon> --seed <seed> --sigma <sigma>

python evaluate.py --task <task> --horizon <horizon> --ctx_len <ctx_len> --alpha <alpha> --seed <seed> --sigma <sigma>
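In both settings the extra flag is passed to all three stages with the same value, presumably so that the trajectories, the trained model, and the evaluation all refer to the same data variant. The sketch below assembles those commands for a sweep over several `--frac` or `--sigma` values. It is hypothetical: the helper name `setting_commands` and the example values are placeholders, the exact meaning of `--frac` and `--sigma` is not documented here, and it only builds argument lists, which can then be executed from the repository root with `subprocess.run(cmd, check=True)`.

```python
import shlex
import sys
from typing import List, Union

Number = Union[int, float]

# Extra flag used by each additional setting, as listed above.
SETTING_FLAGS = {"sparse": "--frac", "noisy": "--sigma"}


def setting_commands(task: str, horizon: int, ctx_len: int, alpha: float, seed: int,
                     setting: str, value: Number) -> List[List[str]]:
    """Commands for one task in the sparse or noisy setting.

    The same setting flag and value are appended to all three stages.
    """
    if setting not in SETTING_FLAGS:
        raise ValueError(f"unknown setting {setting!r}; expected one of {sorted(SETTING_FLAGS)}")
    extra = [SETTING_FLAGS[setting], str(value)]
    py = sys.executable
    return [
        [py, "construct_trajectories.py", "--task", task] + extra,
        [py, "train.py", "--task", task, "--horizon", str(horizon), "--seed", str(seed)] + extra,
        [py, "evaluate.py", "--task", task, "--horizon", str(horizon),
         "--ctx_len", str(ctx_len), "--alpha", str(alpha), "--seed", str(seed)] + extra,
    ]


if __name__ == "__main__":
    # Placeholder sweep values for illustration only.
    for frac in (0.1, 0.5):
        for cmd in setting_commands("<task>", 64, 4, 0.5, 0, "sparse", frac):
            print(shlex.join(cmd))
```

Building the flag once and appending it to every stage avoids the easy mistake of constructing trajectories with one `--frac` value and training or evaluating with another.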

Issues

No file named 'environments' in 'diffuser'

I am currently trying to reproduce the results of the algorithm. When I run 'python train.py', I get the following error:

Traceback (most recent call last):
  File "train.py", line 5, in <module>
    from scripts.train import main
  File "/home/x/GTG/scripts/train.py", line 1, in <module>
    import diffuser.utils as utils
  File "/home/x/GTG/diffuser/__init__.py", line 1, in <module>
    from . import environments
ImportError: cannot import name 'environments' from 'diffuser' (/home/x/GTG/diffuser/__init__.py)

What should I do to reproduce the results of this repository?
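The last frame of the traceback shows `diffuser/__init__.py` failing on `from . import environments`, which usually means Python found no `environments` module or package inside `diffuser/` in this checkout. One way to confirm that before changing any code is to list the relative imports of `diffuser/__init__.py` and check which of them have no matching `.py` file or directory. The helpers below (`relative_imports`, `missing_submodules`) are a hypothetical diagnostic sketch, not part of the repository; they only report missing modules and do not fix them.

```python
import ast
from pathlib import Path
from typing import List


def relative_imports(init_source: str) -> List[str]:
    """Submodule names pulled in by single-dot relative imports in a package's __init__.py."""
    names: List[str] = []
    for node in ast.walk(ast.parse(init_source)):
        if isinstance(node, ast.ImportFrom) and node.level == 1:
            if node.module is None:
                # `from . import a, b` -> submodules a and b
                names.extend(alias.name for alias in node.names)
            else:
                # `from .utils import x` -> submodule utils
                names.append(node.module.split(".")[0])
    return names


def missing_submodules(repo_root: str, package: str = "diffuser") -> List[str]:
    """Relative imports of <package>/__init__.py with no matching .py file or directory."""
    pkg_dir = Path(repo_root) / package
    source = (pkg_dir / "__init__.py").read_text()
    return [
        name
        for name in relative_imports(source)
        if not (pkg_dir / f"{name}.py").is_file() and not (pkg_dir / name).is_dir()
    ]
```

For example, `print(missing_submodules("/home/x/GTG"))` on the checkout from the traceback is expected to include `environments` if that module is indeed absent.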
