
CoIN

This repository is the official implementation of our ACL 2024 Findings paper, Contrastive Instruction Tuning.

Installation

Dependency

Experiments were run in the following environment:

Package   Version
conda     22.9.0
Python    3.8
CUDA      11.8

Install via Conda and Pip

conda create -n coin python=3.8
conda activate coin
pip install -r requirements.txt

Data

Our new dataset is built from the FLAN collection, specifically Muennighoff/flan on Hugging Face. The model we use is Alpaca trained with LoRA; we used the code from Alpaca-LoRA as a starting point and added our own implementation on top of it. We followed the steps described in Section 3.2 of the paper to curate the CoIN dataset, which is available here.

  • Each entry contains:
    • The original instruction-input pair (original_instruction)
    • The paraphrased instruction-input pair (paraphrased_instruction)
    • The label (targets)
    • The task name
    • Keyword data: a dictionary of key-value pairs that are filled into the instruction templates to produce the full input.
  • Instruction templates are available here.
  • Every entry at an odd index (counting from 0) is the hard negative of the entry immediately before it.
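The layout above can be consumed with a few lines of Python. The sketch below is hypothetical and not part of the repository: it assumes the dataset file is a single JSON list and that the instruction templates use str.format-style placeholders such as {premise}; the function names pair_with_hard_negatives, fill_template and load_pairs are illustrative only.

```python
import json
from typing import Any, Dict, List, Tuple

Entry = Dict[str, Any]


def pair_with_hard_negatives(entries: List[Entry]) -> List[Tuple[Entry, Entry]]:
    """Split the flat list into (entry, hard_negative) pairs.

    Follows the layout described above: the entry at each odd index
    (counting from 0) is the hard negative of the entry right before it.
    """
    if len(entries) % 2 != 0:
        raise ValueError("expected an even number of entries")
    return [(entries[i], entries[i + 1]) for i in range(0, len(entries), 2)]


def fill_template(template: str, keywords: Dict[str, str]) -> str:
    """Fill an entry's keyword data into an instruction template.

    Assumption: placeholders use str.format syntax, e.g. "{premise}".
    """
    return template.format(**keywords)


def load_pairs(path: str) -> List[Tuple[Entry, Entry]]:
    """Hypothetical loader: assumes the dataset file is one JSON list."""
    with open(path, encoding="utf-8") as f:
        return pair_with_hard_negatives(json.load(f))


if __name__ == "__main__":
    # Toy entries with only the label field; real entries carry more fields.
    toy = [{"targets": "yes"}, {"targets": "no"}]
    for anchor, negative in pair_with_hard_negatives(toy):
        print(anchor["targets"], "->", negative["targets"])  # yes -> no
    print(fill_template("Premise: {premise}", {"premise": "It rains."}))
```

Pairing by position rather than by a key keeps the loader independent of the exact field names, which the description above does not fully specify.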

Training

Training parameters are defined in scripts/run_contrastive.sh. See ContrastiveLlamaTrainingArgument in run_contrastive_llama.py for the default values of all parameters.

  • To start training the CoIN model, please run the following:
    bash scripts/run_contrastive.sh
    
  • To train the continually instruction-tuned model instead (training with data augmentation only), set do_contrastive to FALSE.
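Setting do_contrastive to FALSE leaves plain training on the augmented data, without the contrastive term. As a rough, hypothetical illustration of what such a term does, the sketch below computes a generic InfoNCE loss in which an anchor representation (e.g. of the original instruction) is pulled toward a positive (its paraphrase) and pushed away from negatives (such as the hard negative). It is a textbook formulation written with the standard library, not the repository's implementation; cosine similarity, the temperature of 0.05 and the names info_nce and cosine are assumptions.

```python
import math
from typing import Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(
    anchor: Sequence[float],
    positive: Sequence[float],
    negatives: Sequence[Sequence[float]],
    temperature: float = 0.05,  # assumed value, not taken from the repo
) -> float:
    """Generic InfoNCE: -log(exp(s_pos / t) / sum_j exp(s_j / t))."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, neg) / temperature for neg in negatives]
    # log-sum-exp with the max subtracted for numerical stability
    top = max(logits)
    log_denominator = top + math.log(sum(math.exp(x - top) for x in logits))
    return log_denominator - logits[0]


if __name__ == "__main__":
    anchor = [1.0, 0.0]
    print(info_nce(anchor, positive=[1.0, 0.0], negatives=[[0.0, 1.0]]))  # close to 0
    print(info_nce(anchor, positive=[0.0, 1.0], negatives=[[1.0, 0.0]]))  # about 20
```

With the positive aligned to the anchor the loss is close to 0; swapping the positive and the negative raises it to about 1 / temperature, which is the signal that pushes representations of paraphrases together and hard negatives apart.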

Evaluation

Run Evaluation on Unseen Instructions

Following PromptBench, we add perturbations to the instructions. All perturbed instructions for the 10 GLUE tasks are available here. To evaluate a model:

  • Open scripts/eval_contrastive.sh
  • Set checkpoint_dir to the path of your checkpoint/output directory
  • Run:
    bash scripts/eval_contrastive.sh
    
  • Change perturb_method and promptbench_eval_task to evaluate the model with different perturbation methods and on different tasks. The supported perturbation methods and tasks are listed in the bash script and in UnseenInstructionEvalArgs in run_contrastive_llama.py.

Postprocessing of Evaluation Results

To obtain the average accuracy (exact match) and standard deviation of the model across the perturbed instructions for each task, run:

python promptbench/postprocessing.py --output_dir "YOUR_OUTPUT_DIR"
  • The evaluation script stores the model's outputs in a directory named preds under the model's checkpoint directory.
  • Replace YOUR_OUTPUT_DIR with the path where these outputs are stored (e.g. output/CoIN/preds).
  • The script writes a CSV file named unseen_instruction_acc.csv to YOUR_OUTPUT_DIR.
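The per-task statistics described above can be approximated as follows. This is a hypothetical sketch, not promptbench/postprocessing.py itself: it assumes predictions and labels are already grouped per task and per perturbed instruction, compares them after stripping whitespace and lowercasing, uses the population standard deviation (statistics.pstdev), and writes illustrative column names that may differ from those in unseen_instruction_acc.csv. The task name "sst2" in the demo is toy data.

```python
import csv
import io
import statistics
from typing import Dict, List, Tuple

# One (predictions, labels) pair per perturbed instruction of a task.
Run = Tuple[List[str], List[str]]


def exact_match_accuracy(preds: List[str], labels: List[str]) -> float:
    """Fraction of predictions equal to the label after strip + lowercase."""
    if not preds or len(preds) != len(labels):
        raise ValueError("preds and labels must be non-empty and equally long")
    hits = sum(p.strip().lower() == y.strip().lower() for p, y in zip(preds, labels))
    return hits / len(preds)


def summarize(results: Dict[str, List[Run]]) -> Dict[str, Tuple[float, float]]:
    """Map each task to (mean accuracy, std) across its perturbed instructions."""
    summary = {}
    for task, runs in results.items():
        accs = [exact_match_accuracy(preds, labels) for preds, labels in runs]
        # Assumption: population std; the real script may use sample std.
        summary[task] = (statistics.mean(accs), statistics.pstdev(accs))
    return summary


def to_csv(summary: Dict[str, Tuple[float, float]]) -> str:
    """Render the summary as CSV text (column names are illustrative)."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["task", "mean_acc", "std_acc"])
    for task, (mean, std) in sorted(summary.items()):
        writer.writerow([task, f"{mean:.4f}", f"{std:.4f}"])
    return buf.getvalue()


if __name__ == "__main__":
    toy = {"sst2": [(["positive", "negative"], ["positive", "positive"]),
                    (["positive", "positive"], ["positive", "positive"])]}
    print(to_csv(summarize(toy)))
```

In the toy run, the two perturbed instructions score 0.5 and 1.0, so the task gets a mean accuracy of 0.75 and a standard deviation of 0.25.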

Citation

@inproceedings{yan2024contrastive,
  title={Contrastive Instruction Tuning},
  author={Yan, Tianyi and Wang, Fei and Huang, James Y and Zhou, Wenxuan and Yin, Fan and Galstyan, Aram and Yin, Wenpeng and Chen, Muhao},
  booktitle={ACL - Findings},
  year={2024}
}
