Adversarial Preference Optimization


This repo contains the implementation of Adversarial Preference Optimization (APO).

We let the reward model (RM) and LLM agent play a min-max game, through which both models can be further enhanced without additional preference annotation.
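Schematically (a simplified sketch following the paper's formulation, not a transcription of the training code), with reward model r and LLM policy π, the two models are updated alternately on the objective

$$\min_{\pi}\ \max_{r}\ \mathbb{E}_{x\sim\mathcal{D}}\Big[\,\mathbb{E}_{y\sim p_{\text{golden}}(\cdot\mid x)}\big[r(x,y)\big]\;-\;\mathbb{E}_{y\sim\pi(\cdot\mid x)}\big[r(x,y)\big]\Big],$$

so the RM learns to separate golden answers from LLM samples, while the LLM learns to close that gap.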

Currently, the repo contains the split HH data and the RM training code described below. We are continuously updating this repo for the reproduction of the APO experiments.

Data & Annotation

To update the RM and the LLM separately, we split the cleaned Helpful & Harmless (HH) dataset into an RM training set and an LLM training set.

| Data Type | HH-RM Train Set | HH-LLM Train Set | HH Test Set |
| --- | --- | --- | --- |
| Preference Pairs | RM training set | RM validation set | RM testing set |
| Golden Answers | APO positive responses | - | - |
| User Queries | APO negative responses (Alpaca samples) | LLM (Alpaca) rejection samples | LLM testing queries |

Environment

We use Python 3.8 with the dependencies listed in requirements.txt. To set up the environment, run the following command:

pip3 install -r requirements.txt

Base RM Training

To train the base RM for rejection sampling, use the following command:

REPO_DIR=<path_to_this_repo>
DATA_DIR=${REPO_DIR}/data/hh-split
TRAIN_DATA_LIST="${DATA_DIR}/rm_data/hh_split_rm.train.json"
TEST_DATA_LIST="${DATA_DIR}/eval_data/hh_cleaned_origin.test.json\
		${DATA_DIR}/eval_data/hh_split_llm.valid.json"
		
NUM_GPUS=8
BATCH_SIZE=64
MICRO_BATCH_SIZE=1
LEARNING_RATE=1e-6
GRADIENT_ACCUMULATION_STEP=$((BATCH_SIZE / NUM_GPUS / MICRO_BATCH_SIZE))
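# with the defaults above: 64 / 8 / 1 = 8 micro-batches accumulated per optimizer update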

torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \
    --task_type hh_split \
    --do_train True \
    --eval_at_start False \
    --model_type reward \
    --model_name_or_path <path_to_llama_7b_checkpoint_and_tokenizer> \
    --data_type comparison_pair \
    --train_data_path ${TRAIN_DATA_LIST} \
    --eval_data_path ${TEST_DATA_LIST} \
    --data_suffix rm_base \
    --add_sep_token True \
    --remove_unused_columns false \
    --output_dir <path_to_save_your_RM_checkpoint> \
    --num_train_epochs 1 \
    --per_device_train_batch_size ${MICRO_BATCH_SIZE} \
    --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \
    --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEP} \
    --evaluation_strategy steps \
    --padding_side right \
    --truncation_side left \
    --pooling_type last \
    --max_length 512 \
    --save_strategy steps \
    --save_total_limit 10 \
    --learning_rate ${LEARNING_RATE} \
    --warmup_steps 100 \
    --logging_steps 10 \
    --eval_steps 50 \
    --weight_decay 0. \
    --deepspeed configs/default_offload_opt_param.json \
    --tf32 false --fp16 false

We also train a testing RM to automatically evaluate the LLM response samples on the test queries. To train the testing RM, change TRAIN_DATA_LIST in the command above so that the RM learns from all the HH training comparisons:
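TRAIN_DATA_LIST="${DATA_DIR}/hh_cleaned_origin.train.json"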

APO RM Training

To train the APO RM, first merge the LLM samples and the golden annotations into APO comparison pairs:

REPO_DIR=<path_to_this_repo>
DATA_DIR="${REPO_DIR}/data/hh-split"

python3 ${REPO_DIR}/tools/apo_data_converter.py \
	--golden_data_path ${DATA_DIR}/rm_data/hh_split_rm.golden.json \
	--sample_data_path ${DATA_DIR}/rm_data/hh_split_rm_alpaca_v0.sample.json \
	--output_dir ${DATA_DIR}/apo_data \
	--apo_data_name "rm_apo_data_v0"
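The converter writes the merged comparisons to ${DATA_DIR}/apo_data/rm_apo_data_v0_text_scores.json, which is the file referenced by TRAIN_DATA_LIST in the APO training command below.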

Then use the following command to conduct APO fine-tuning of the base RM:

REPO_DIR=<path_to_this_repo>
DATA_DIR=${REPO_DIR}/data/hh-split
TRAIN_DATA_LIST="${DATA_DIR}/rm_data/hh_split_rm.train.json \
		 ${DATA_DIR}/apo_data/rm_apo_data_v0_text_scores.json"
NUM_APO_SAMPLES=4
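# number of LLM response samples per query in the APO data (our reading of --apo_sample_num)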

TEST_DATA_LIST="${DATA_DIR}/eval_data/hh_cleaned_origin.test.json \
		${DATA_DIR}/eval_data/hh_split_llm.valid.json"
		
NUM_GPUS=8
BATCH_SIZE=64
MICRO_BATCH_SIZE=1
LEARNING_RATE=1e-7
APO_COEFF=0.1
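# weight of the APO loss term (passed to --apo_loss_coeff below)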
GRADIENT_ACCUMULATION_STEP=$((BATCH_SIZE / NUM_GPUS / MICRO_BATCH_SIZE))


torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \
    --task_type hh_split \
    --do_train True \
    --eval_at_start False \
    --model_type reward \
    --model_name_or_path <path_to_RM_Base_checkpoint> \
    --data_type comparison_pair \
    --train_data_path ${TRAIN_DATA_LIST} \
    --eval_data_path ${TEST_DATA_LIST} \
    --data_suffix rm_apo_v1 \
    --add_sep_token True \
    --remove_unused_columns false \
    --output_dir <path_to_save_your_APO_RM_checkpoint> \
    --num_train_epochs 1 \
    --apo_loss_coeff ${APO_COEFF} \
    --apo_sample_num ${NUM_APO_SAMPLES} \
    --per_device_train_batch_size ${MICRO_BATCH_SIZE} \
    --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \
    --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEP} \
    --evaluation_strategy steps \
    --padding_side right \
    --truncation_side left \
    --pooling_type last \
    --max_length 512 \
    --save_strategy steps \
    --save_total_limit 10 \
    --learning_rate ${LEARNING_RATE} \
    --warmup_steps 100 \
    --logging_steps 10 \
    --eval_steps 50 \
    --weight_decay 0. \
    --deepspeed configs/default_offload_opt_param.json \
    --tf32 false --fp16 false
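Schematically (our reading of the flags above, not a transcription of the repo's loss code), the APO RM objective adds the adversarial comparison loss on the golden-vs-sample pairs to the standard ranking loss on the preference pairs, weighted by --apo_loss_coeff:

$$\mathcal{L}_{\text{RM}} \;=\; \mathcal{L}_{\text{rank}} \;+\; \lambda_{\text{APO}}\,\mathcal{L}_{\text{APO}}, \qquad \lambda_{\text{APO}} = 0.1 \text{ in the command above.}$$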

Citation

@article{cheng2023adversarial,
  title={Adversarial Preference Optimization},
  author={Cheng, Pengyu and Yang, Yifan and Li, Jian and Dai, Yong and Du, Nan},
  journal={arXiv preprint arXiv:2311.08045},
  year={2023}
}
