
DEX: Demonstration-Guided RL with Efficient Exploration for Task Automation of Surgical Robot

This is the official PyTorch implementation of the paper "Demonstration-Guided Reinforcement Learning with Efficient Exploration for Task Automation of Surgical Robot" (ICRA 2023).

Prerequisites

  • Ubuntu 18.04
  • Python 3.7+

Installation Instructions

  1. Clone this repository:
git clone --recursive https://github.com/med-air/DEX.git
cd DEX
  2. Create a virtual environment:
conda create -n dex python=3.8
conda activate dex
  3. Install packages:
pip3 install -e SurRoL/	# install SurRoL environments
pip3 install -r requirements.txt
pip3 install -e .
  4. Add one line of code at the top of gym/gym/envs/__init__.py to register the SurRoL tasks:
# directory: anaconda3/envs/dex/lib/python3.8/site-packages/
import surrol.gym

Usage

Below are the commands for training DEX and all baselines. Results are logged to WandB. Before running them, change the wandb entity in train.yaml to match your own account.

We collect demonstration data via the scripted controllers provided by SurRoL. Taking the NeedlePick task as an example:

mkdir SurRoL/surrol/data/demo
python SurRoL/surrol/data/data_generation.py --env NeedlePick-v0 
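
The generated demonstrations are saved as .npz files, presumably under the demo directory created above. Below is a minimal sketch for inspecting such a file with NumPy; the file name is an assumption based on the data_<task>_random_<episodes>.npz pattern, so adjust the path to whatever data_generation.py actually produced.

import numpy as np

# hypothetical path; check SurRoL/surrol/data/demo for the actual file name
demo_path = "SurRoL/surrol/data/demo/data_NeedlePick-v0_random_100.npz"
data = np.load(demo_path, allow_pickle=True)

print(data.files)                              # arrays stored in the demo file
first_episode = data["obs"][0]                 # observations of the first episode
print(first_episode[0]["observation"].shape)   # per-step observation vector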

Training Commands

  • Train DEX:
python3 train.py task=NeedlePick-v0 agent=dex use_wb=True
  • Train SAC:
python3 train.py task=NeedlePick-v0 agent=sac use_wb=True
  • Train DDPG:
python3 train.py task=NeedlePick-v0 agent=ddpg use_wb=True
  • Train DDPGBC:
python3 train.py task=NeedlePick-v0 agent=ddpgbc use_wb=True
  • Train CoL:
python3 train.py task=NeedlePick-v0 agent=col use_wb=True
  • Train AMP:
python3 train.py task=NeedlePick-v0 agent=amp use_wb=True
  • Train AWAC:
python3 train.py task=NeedlePick-v0 agent=awac use_wb=True
  • Train SQIL:
python3 train.py task=NeedlePick-v0 agent=sqil use_wb=True

All of the above commands can be run on other surgical tasks by replacing NeedlePick with the desired environment (for both demo collection and RL training); see the sketch below for chaining several runs.
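
If you want to sweep several tasks and baselines in one go, a small launcher script can simply chain the commands above. The following is a minimal sketch (not part of the repository); the task and agent lists are just examples, so edit them to match the environments and methods you actually want to run.

import subprocess

tasks = ["NeedlePick-v0", "NeedleRegrasp-v0"]   # example SurRoL task names
agents = ["dex", "ddpgbc", "sac"]               # example agents from dex/configs/agent

for task in tasks:
    # collect scripted demonstrations for the task
    subprocess.run(
        ["python", "SurRoL/surrol/data/data_generation.py", "--env", task],
        check=True,
    )
    for agent in agents:
        # launch RL training for this agent on this task
        subprocess.run(
            ["python3", "train.py", f"task={task}", f"agent={agent}", "use_wb=True"],
            check=True,
        )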

We also implement synchronous parallelization of RL training; for example, to launch 4 parallel training processes:

mpirun -np 4 python -m train agent=dex task=NeedlePick-v0 use_wb=True

Note that parallel training may lead to inconsistent performance and typically requires additional hyperparameter tuning.

Evaluation Commands

We also provide a script for evaluating saved models. The directory of the model to be evaluated should be specified in the configuration file eval.yaml, where the checkpoint is selected via ckpt_episode. For instance:

  • Evaluate a model trained by DEX on NeedlePick-v0:
python3 eval.py task=NeedlePick-v0 agent=dex ckpt_episode=latest

Starting to Modify the Code

Modifying the hyperparameters

The default hyperparameters are defined in dex/configs, where train.yaml specifies the experiment settings and the YAML files in the agent directory specify the hyperparameters of each method. These parameters can be modified directly in the experiment or agent config files, or overridden on the command line. For example:

python3 train.py task=NeedleRegrasp-v0 agent=dex use_wb=True batch_size=256 agent.aux_weight=10

Adding a new RL algorithm

The core RL algorithms are implemented within the BaseAgent class. To add a new algorithm, create a new file in dex/agents and subclass BaseAgent. In particular, construct any required networks (actor, critic, etc.) and override the update(...) and get_action(...) functions. For an example, see the DDPGBC implementation. Once the implementation is done, register the new agent in factory.py and add a config file under agent to specify its model parameters. A minimal sketch is shown below.
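
For orientation, here is a minimal sketch of what such a subclass might look like. It is not the actual interface of this repository: the import path, constructor arguments, buffer API, and the exact signatures of update(...) and get_action(...) below are illustrative assumptions, so please mirror an existing agent (e.g., the DDPGBC file in dex/agents) for the real ones.

import torch
import torch.nn as nn

from .base import BaseAgent   # hypothetical module path for BaseAgent


class MyAgent(BaseAgent):
    def __init__(self, obs_dim, act_dim, lr=3e-4, gamma=0.98, **kwargs):
        super().__init__(**kwargs)
        # construct the required networks (actor, critic, etc.)
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, act_dim), nn.Tanh()
        )
        self.critic = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 256), nn.ReLU(), nn.Linear(256, 1)
        )
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.gamma = gamma

    def get_action(self, obs):
        # map an observation to an action for environment interaction
        with torch.no_grad():
            action = self.actor(torch.as_tensor(obs, dtype=torch.float32))
        return action.numpy()

    def update(self, replay_buffer, batch_size=256):
        # sample a transition batch; the buffer's sample() signature is an assumption
        batch = replay_buffer.sample(batch_size)
        obs, act, rew, next_obs = batch["obs"], batch["act"], batch["rew"], batch["next_obs"]

        # critic update with a one-step TD target (DDPG-style)
        with torch.no_grad():
            next_q = self.critic(torch.cat([next_obs, self.actor(next_obs)], dim=-1))
            target_q = rew + self.gamma * next_q
        critic_loss = nn.functional.mse_loss(self.critic(torch.cat([obs, act], dim=-1)), target_q)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        # actor update via the deterministic policy gradient
        actor_loss = -self.critic(torch.cat([obs, self.actor(obs)], dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}

After the class is implemented, remember to register it in factory.py and add a matching config file under dex/configs/agent.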

Transferring to other simulation platforms

Our code is designed for standard goal-conditioned gym-based environments and can be easily transferred to other platforms that provide the same interfaces (e.g., the OpenAI Gym Fetch environments). If no similar interface is provided, some modifications are needed for compatibility, e.g., to the replay buffer and sampling utilities. We will make our code more generalizable in the future. A sketch of the expected interface is given below.
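
Concretely, the environment should expose the dict-style goal-conditioned observations and the compute_reward function of the Gym GoalEnv / Fetch interface. The sketch below only illustrates those required fields and is not part of this repository; the action, observation, and goal dimensions are placeholder values.

import numpy as np
import gym
from gym import spaces


class MyGoalEnv(gym.Env):
    # a stand-in environment exposing the goal-conditioned interface the code expects

    def __init__(self):
        self.action_space = spaces.Box(-1.0, 1.0, shape=(5,), dtype=np.float32)  # placeholder action dim
        self.observation_space = spaces.Dict({
            "observation":   spaces.Box(-np.inf, np.inf, shape=(35,), dtype=np.float32),
            "achieved_goal": spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32),
            "desired_goal":  spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32),
        })

    def reset(self):
        # dict observation consumed by the goal-conditioned (HER) replay buffer
        return {
            "observation":   np.zeros(35, dtype=np.float32),
            "achieved_goal": np.zeros(3, dtype=np.float32),
            "desired_goal":  np.ones(3, dtype=np.float32),
        }

    def step(self, action):
        obs = self.reset()  # placeholder dynamics
        reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {})
        info = {"is_success": float(reward == 0.0)}  # success flag, in the style of the Fetch tasks
        return obs, reward, False, info

    def compute_reward(self, achieved_goal, desired_goal, info):
        # sparse goal-reaching reward, as in the Gym Fetch tasks
        return -float(np.linalg.norm(achieved_goal - desired_goal) > 0.05)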

Code Navigation

dex
  |- agents                # implements core algorithms in agent classes
  |- components            # reusable infrastructure for model training
  |    |- checkpointer.py  # handles saving + loading of model checkpoints
  |    |- normalizer.py    # normalizer for vectorized input
  |    |- logger.py        # implements core logging functionality using wandB
  |
  |- configs               # experiment configs 
  |    |- train.yaml       # configs for rl training
  |    |- eval.yaml        # configs for rl evaluation
  |    |- agent            # configs for each algorithm (dex, ddpg, ddpgbc, etc.)
  |
  |- modules               # reusable architecture components
  |    |- critic.py        # basic critic implementations (e.g., MLP-based critic)
  |    |- distributions.py # pytorch distribution utils for density model
  |    |- policy.py        # basic actor implementations
  |    |- replay_buffer.py # HER replay buffer with future sampling strategy
  |    |- sampler.py       # rollout sampler for collecting experience
  |    |- subnetworks.py   # basic networks
  |
  |- trainers              # main model training script, builds all components + runs training loop and logging
  |
  |- utils                 # general and rl utilities, pytorch / visualization utilities etc
  |- train.py              # experiment launcher
  |- eval.py               # evaluation launcher

Contact

For any questions, please feel free to email [email protected].


dex's Issues

where is the training data file (.npz)?

Dear Sir,
Thanks for previously answering the EGL question; it was indeed an EGL problem, and I finally used a GPU card to bypass it.
After running python3 train.py task=NeedlePick-v0 agent=dex use_wb=False, I can find the .pth files, but I cannot find the .npz files. Where can I find them for python SurRoL/surrol/data/data_generation.py --env NeedlePick-v0?
Another question: training crashes after a few hours with "pybullet.error: Not connected to physics server", and when I restart training, all the previous .pth files are deleted. When training is interrupted like this, how can I resume it?
Thanks a lot.

Detail parameters in Expert Demonstration Data

May I know the meaning of each value in the demonstration data?

>>> import numpy as np
>>> data = np.load('data_BiPegTransfer-v0_random_100.npz', allow_pickle=True)
>>> data['obs'][0][0]['observation']
array([ 2.50000024e+00,  2.50000328e-01,  3.77600002e+00, -1.57077582e+00,
       -7.86881280e-06,  1.57089384e+00, -2.68686934e-08,  2.50000024e+00,
       -2.50000358e-01,  3.77600002e+00, -1.57077499e+00,  8.22652205e-06,
        1.57069500e+00,  2.63462330e-08,  2.81499942e+00, -1.35001428e-01,
        3.52300466e+00,  3.14999185e-01, -3.85001756e-01, -2.52995362e-01,
        3.14999185e-01,  1.14998930e-01, -2.52995362e-01,  2.83223295e+00,
       -1.16891325e-01,  3.52300572e+00,  1.95795628e-05,  3.74380987e-05,
       -2.33138680e+00,  2.82206702e+00, -1.58981025e-01,  3.52300453e+00,
       -4.22121162e-05, -1.76266571e-06,  1.85740287e+00])
>>> data['obs'][0][0]['observation'].shape
(35,)

The observation's size is 35 for the BiPegTransfer task; what do those values represent?

Thank you in advance!

wandb problem

I have created a wandb account and changed the wandb entity in train.yaml before running the command, but the error below appears. Is there anything else that should be modified?

[Screenshot, 2024-07-21: wandb error message]

Adjusting x and y Angles

Hello,

I find your approach works impressively well. I am currently trying to extend the needle pick to a starting position where the needle is not lying flat on a surface but is partially inserted into the tissue. For now, this can't work because in training the sampler only generates episodes with the rotations around the x- and y-axes being zero. If I change this, I assume that I also have to change some hyperparameters in train.yaml. Do you have any suggestions on which ones would need to change?

Transform Needle Pose to World Frame

Hello,

I managed to export the trajectory and transform it from the PyBullet world frame to the RCM frame. I also managed to set the needle poses, but I have one final question.

When setting the needle position, it seems like I need to do this in the PyBullet frame. However, since I want to use the trajectory in a different simulation environment, I have to transform the desired pose of the needle to the PyBullet world frame first. I'm unsure how to do this transformation. I know how to transform my desired needle pose to the RCM frame though, but I don't know how to get to the PyBullet frame from there. The world2rcm function appears to work only for the robot arm.
I would appreciate any suggestions.

Coordinate Transformation and Workspace

Hello,

I am receiving a sensible cartesian trajectory for the needle pick.
[Plot of the learned Cartesian trajectory]
However, the Cartesian positions don't quite match the joint positions. For example, providing the dVRK with the joint positions [0.41041827508953993, -0.5074103527953661, 0.16461946887696666, 0.48424603099967994, 0.6265858251749896, -0.1466937918716302], I would expect to receive a position close to x=0.0312652, y=0.0035137, z=0.11708. The learned trajectory, however, gives me x=2.6260259151458745, y=0.01377844344824547, z=3.5077011585235596.

I assume this is due to the pyBullet coordinate frame. Could you maybe provide some details on how to transform this to the robot coordinate system?

Additionally, the area in which the needle can be picked is confined by the workspace limits 2.5<x<3, -0.25<y<0.25, 3.426<z<3.776. Is it possible to extend this so that it would be possible to pick a needle with z=0.0?

I would really appreciate your help.

eval error

After the first round of the training loop, I got the error below. I notice the program stopped when body=1 and link=13; it always stops here:

File "/media/u/all_code/SurRol2/SurRoL/rl/train.py", line 9, in main
    exp.train()
File "/media/u/all_code/SurRol2/SurRoL/rl/trainers/rl_trainer.py", line 109, in train
    score = self.eval()
File "/media/u/all_code/SurRol2/SurRoL/rl/trainers/rl_trainer.py", line 182, in eval
    episode, _, env_steps = self.eval_sampler.sample_episode(is_train=False, render=True)
File "/media/u/all_code/SurRol2/SurRoL/rl/modules/samplers.py", line 25, in sample_episode
    self.init()
File "/media/u/all_code/SurRol2/SurRoL/rl/modules/samplers.py", line 18, in init
    self._episode_reset()
File "/media/u/all_code/SurRol2/SurRoL/rl/modules/samplers.py", line 55, in _episode_reset
    self._obs = self._reset_env()
File "/media/u/all_code/SurRol2/SurRoL/rl/modules/samplers.py", line 59, in _reset_env
    return self._env.reset()
File "/home/u/anaconda3/envs/surrol3/lib/python3.7/site-packages/gym/wrappers/time_limit.py", line 26, in reset
    return self.env.reset(**kwargs)
File "/home/u/anaconda3/envs/surrol3/lib/python3.7/site-packages/gym/wrappers/order_enforcing.py", line 18, in reset
    return self.env.reset(**kwargs)
File "/media/u/all_code/SurRol2/SurRoL/surrol/gym/surrol_goalenv.py", line 19, in reset
    return super().reset()
File "/media/u/all_code/SurRol2/SurRoL/surrol/gym/surrol_env.py", line 147, in reset
    self._sample_goal_callback()
File "/media/u/all_code/SurRol2/SurRoL/surrol/tasks/needle_pick.py", line 104, in _sample_goal_callback
    pos_obj, orn_obj = get_link_pose(self.obj_id, self.obj_link1)
File "/media/u/all_code/SurRol2/SurRoL/surrol/utils/pybullet_utils.py", line 725, in get_link_pose
    link_state = get_link_state(body, link)
File "/media/u/all_code/SurRol2/SurRoL/surrol/utils/pybullet_utils.py", line 707, in get_link_state
    return LinkState(*p.getLinkState(body, link))

How to Apply Learned Policy

Hello,
I didn't have any problems running the code. However, after successfully running the training algorithms (e.g., DDPG), I don't know how to actually apply the learned policy to get the end-effector trajectory for the NeedlePick task. By that I mean not just running the evaluation script and getting, for example, the success rate, but really retrieving output that I can use on my robot.
I apologize if this question is too basic, but I would be thankful for every hint.

How to use the trained model?

I have followed your steps for training, but how should I get test data and use the trained model for testing? Looking forward to your reply.

failed to EGL with glad.

Dear Professor,
I installed everything as described, and the demo command python SurRoL/surrol/data/data_generation.py --env NeedlePick-v0 runs OK. But the training commands fail, showing "failed to EGL with glad." Is a GPU necessary? (I have no GPU.) Or is there anything I missed?
