Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals

This repository contains the official code of the NeurIPS 2020 paper Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals.

Installation

conda env create -f environment.yml

Experiment pipeline of FallingDigit

Note

  • Our method is a two-stage framework. The first stage trains an RL teacher policy and collects a demonstration dataset; the second stage trains a self-supervised object detector and a GNN-based student policy on that demonstration dataset. Please follow the steps below in order.
  • We provide example configuration files in the configs directory, and most experiment outputs will be saved in the outputs directory.
  • Game environment names follow the format FallingDigit${bg}_${n}-v1, where ${bg} is Black or CIFAR and ${n} is the number of target digits (from 3 to 9). We also provide the test environments FallingDigit${bg}_${n}_test-v1, which contain game levels different from the training environment (each level is generated by a unique random seed). Since the environments differ by background, make sure you use environments with the same background through the whole experiment pipeline (see the sketch after this list).
  • Since training a teacher policy by RL takes some time, we provide two trained teacher policies (one trained on FallingDigitBlack_3-v1 and one trained on FallingDigitCIFAR_3-v1). With them, you can skip teacher training and start from the step of collecting the demonstration dataset.
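
For concreteness, here is a minimal sketch of how an environment name is assembled from the background and digit count described above. It assumes the FallingDigit environments are registered with Gym once the repo's environment module is imported; the import and module name below are hypothetical, so check the repo for the actual one.

# A minimal, illustrative sketch of the naming scheme above; not part of the official pipeline.
# The import below is hypothetical -- check the repo for the module that registers the environments.
# import falling_digit_envs
import gym

bg = "CIFAR"  # "Black" or "CIFAR"; keep the same background through the whole pipeline
n = 3         # number of target digits, from 3 to 9

train_env_id = f"FallingDigit{bg}_{n}-v1"      # e.g. FallingDigitCIFAR_3-v1
test_env_id = f"FallingDigit{bg}_{n}_test-v1"  # held-out levels generated with different random seeds

env = gym.make(train_env_id)  # works once the environments are registered with Gym
obs = env.reset()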

Train a teacher policy by DQN

python dqn/main.py --cfg configs/falling_digit_rl/dqn_relation_net.yml env FallingDigitCIFAR_3-v1

The checkpoint directory will be something like ./outputs/falling_digit_rl/dqn_relation_net/11-09_22-16-09_FallingDigitCIFAR_3-v1.

Select a teacher checkpoint and collect demonstration dataset

python tools/select_teacher_checkpoint.py --env FallingDigitCIFAR_3-v1 \
    --cfg configs/falling_digit_rl/dqn_relation_net_eval.yml \
    --ckpt-dir ${YOUR_RL_OUTPUT_DIR}
python tools/collect_demo_dataset_for_falling_digit.py --env FallingDigitCIFAR_3-v1 \
    --cfg configs/falling_digit_rl/dqn_relation_net_eval.yml \
    --ckpt ${THE_SELECTED_GOOD_TEACHER_CHECKPOINT_PATH}

The collected demonstration dataset will be saved in the data directory.

Train a self-supervised object detector and generate object proposals for the demo dataset

Paste the path of the collected demonstration dataset into configs/falling_digit_space/cifar_space_v1.yaml. Specifically, set both DATASET.TRAIN.path and DATASET.VAL.path to that path; we use different splits of the same dataset as the training set and the validation set. Then run the following commands.

python space/train_space.py --cfg configs/falling_digit_space/cifar_space_v1.yaml
python space/predict_space.py --cfg configs/falling_digit_space/cifar_space_v1.yaml
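
If you prefer not to hand-edit the YAML, the sketch below fills in the dataset paths programmatically. It assumes the config is plain YAML with nested DATASET.TRAIN.path and DATASET.VAL.path keys; the demonstration-dataset filename used here is hypothetical, so substitute the file actually produced in the data directory.

# A minimal sketch: write the demo-dataset path into the SPACE config programmatically.
# Assumes the config is plain YAML; the dataset filename below is hypothetical.
import yaml

cfg_path = "configs/falling_digit_space/cifar_space_v1.yaml"
demo_path = "data/FallingDigitCIFAR_3-v1_demo.pkl"  # replace with the file actually saved in data/

with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

# Training and validation use different splits of the same demonstration dataset,
# so both paths point to the same file.
cfg["DATASET"]["TRAIN"]["path"] = demo_path
cfg["DATASET"]["VAL"]["path"] = demo_path

with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f)  # note: re-dumping discards any comments in the original YAML

The same approach applies to configs/falling_digit_refactor/cifar_gnn.yaml in the student-policy step below.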

Train a GNN-based student policy

Similarly, paste the path of the collected demonstration dataset into configs/falling_digit_refactor/cifar_gnn.yaml. Then run the following command.

python refactorization/train_gnn.py --cfg configs/falling_digit_refactor/cifar_gnn.yaml

Test a GNN-based student policy

python tools/eval_student_policy.py \
    --env FallingDigitCIFAR_9_test-v1 \
    --n-episode 100 \
    gnn \
    --detector-model SPACE_v1 \
    --detector-checkpoint ${YOUR_DETECTOR_CHECKPOINT_PATH} \
    --gnn-model EdgeConvNet \
    --gnn-checkpoint ${YOUR_GNN_POLICY_CHECKPOINT_PATH}

${YOUR_DETECTOR_CHECKPOINT_PATH} should be something like outputs/falling_digit_space/cifar_space_v1/model_060000.pth, and ${YOUR_GNN_POLICY_CHECKPOINT_PATH} should be something like outputs/falling_digit_refactor/cifar_gnn/model_best.pth. Note that the test environments FallingDigit${bg}_${n}_test-v1 should be used here.
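
To probe compositional generalization, you will typically evaluate the same student policy on test environments with increasing numbers of digits (e.g. trained on 3 digits, tested on 3 to 9). The loop below is only a convenience sketch that repeats the documented command; the checkpoint paths are placeholders.

# Convenience sketch: repeat the documented evaluation command for 3 to 9 digits.
import subprocess

detector_ckpt = "outputs/falling_digit_space/cifar_space_v1/model_060000.pth"  # placeholder path
gnn_ckpt = "outputs/falling_digit_refactor/cifar_gnn/model_best.pth"           # placeholder path

for n in range(3, 10):  # FallingDigitCIFAR_3_test-v1 ... FallingDigitCIFAR_9_test-v1
    subprocess.run([
        "python", "tools/eval_student_policy.py",
        "--env", f"FallingDigitCIFAR_{n}_test-v1",
        "--n-episode", "100",
        "gnn",
        "--detector-model", "SPACE_v1",
        "--detector-checkpoint", detector_ckpt,
        "--gnn-model", "EdgeConvNet",
        "--gnn-checkpoint", gnn_ckpt,
    ], check=True)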

Citation

If you find our paper useful in an academic setting, please cite:

@article{mu2020refactoring,
  title={Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals},
  author={Mu, Tongzhou and Gu, Jiayuan and Jia, Zhiwei and Tang, Hao and Su, Hao},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

Acknowledgments

The self-supervised object detector part of this implementation refers to some details in Zhixuan Lin's original implementation. The reinforcement learning part of this implementation is adapted from Shangtong Zhang's DeepRL codebase.
