

PADA

PADA is an example-based prompt generation model that adapts on-the-fly to unseen domains (or distributions in general). It is trained on labeled data from multiple domains, and when presented with new examples (from unknown domains), it performs autoregressive inference in two steps: (1) it first generates an example-specific signature that maps the input example to the semantic space spanned by features of its training domains (Domain Related Features, or DRFs); and (2) it then casts the generated signature as a prompt (prefix) and performs the downstream task conditioned on it.
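To make the two inference steps concrete, here is a minimal sketch in the spirit of our T5-based implementation. It is illustrative only: the 'Domain:' trigger phrase, the prompt format, and the generation lengths are placeholders, not the exact ones used in this repository.

# Illustrative sketch of PADA's two-step inference (not the repository's exact API).
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")  # stands in for a trained PADA checkpoint

def pada_inference(example_text: str) -> str:
    # Step 1: generate an example-specific signature (domain name + DRFs).
    gen_inputs = tokenizer("Domain: " + example_text, return_tensors="pt")
    signature_ids = model.generate(**gen_inputs, max_length=20)
    signature = tokenizer.decode(signature_ids[0], skip_special_tokens=True)

    # Step 2: cast the generated signature as a prompt (prefix) and perform the task.
    task_inputs = tokenizer(signature + ": " + example_text, return_tensors="pt")
    pred_ids = model.generate(**task_inputs, max_length=10)
    return tokenizer.decode(pred_ids[0], skip_special_tokens=True)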


If you use this code, please cite our paper (see the recommended citation below).

Our code is implemented in PyTorch, using the Transformers and PyTorch-Lightning libraries.

Usage Instructions

Before diving into our running example of how to run PADA, make sure your virtual environment includes all requirements (specified in 'pada_env.yml').

We ran our experiments on a single NVIDIA Quadro RTX 6000 24GB GPU, with CUDA 11.1 and PyTorch 1.7.1.

0. Setup a conda environment

You can run the following command to create a conda environment from our .yml file:

conda env create --file pada_env.yml
conda activate pada

Training a PADA Model

You can run all the steps below, for all our experiments for a given task (Rumor Detection, 'rumor', or Aspect Prediction, 'absa'), with a single command by running the run-rumor-train-experiments.sh script:

bash run-rumor-train-experiments.sh

Running a single experiment with PADA consists of the following steps:

  1. Define an experimental setup - Choose a single target domain of a given task (e.g., charliehebdo from 'Rumor Detection') and its corresponding source domains (ferguson, germanwings-crash, ottawashooting, sydneysiege).
  2. Extract the DRF sets for each of the source domains.
  3. Annotate training examples with DRF-based prompts.
  4. Run PADA - train PADA on the prompt-annotated training set and test it on the target domain test set.

Next, we go through these steps using our running example:

  • Task - Rumor Detection.
  • Source domains - ferguson, germanwings-crash, ottawashooting, sydneysiege.
  • Target domain - charliehebdo

We use a specific set of hyperparameters (please refer to our paper for more details).

1. Define an experimental setup

GPU_ID=<ID of GPU>
export PYTHONPATH=<path to repository root>
export TOKENIZERS_PARALLELISM=false

TASK_NAME=rumor
ROOT_DATA_DIR=${TASK_NAME}_data
SRC_DOMAINS=ferguson,germanwings-crash,ottawashooting,sydneysiege
TRG_DOMAIN=charliehebdo

TRAIN_BATCH_SIZE=32
EVAL_BATCH_SIZE=32
NUM_EPOCHS=5
ALPHA_VAL=0.2

2. Extract DRF sets

Run the following command to extract a DRF set for each of the source domains:

python ./src/utils/drf_extraction.py \
    --domains ${SRC_DOMAINS} \
    --dtype ${TASK_NAME} \
    --drf_set_location ./runs/${TASK_NAME}/${TRG_DOMAIN}/drf_sets

This saves four files, each named '<SRC_DOMAIN_NAME>.pkl', in the following directory: './runs/<TASK_NAME>/<TRG_DOMAIN>/drf_sets'.
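To sanity-check this step, you can load and inspect the extracted sets. The snippet below assumes each file is a standard pickle serialization of a domain's DRF set (e.g., a list or set of tokens); adjust it if the stored structure differs:

# Hedged sketch: inspect the extracted DRF sets for the running example.
import pickle
from pathlib import Path

drf_dir = Path("./runs/rumor/charliehebdo/drf_sets")
for pkl_file in sorted(drf_dir.glob("*.pkl")):
    with pkl_file.open("rb") as f:
        drf_set = pickle.load(f)
    print(f"{pkl_file.stem}: {len(drf_set)} DRFs, e.g. {list(drf_set)[:5]}")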

3. Annotate training examples with DRF-based prompts

python ./src/utils/prompt_annotation.py \
    --domains ${SRC_DOMAINS} \
    --root_data_dir ${ROOT_DATA_DIR} \
    --drf_set_location ./runs/${TASK_NAME}/${TRG_DOMAIN}/drf_sets \
    --prompts_data_dir ./runs/${TASK_NAME}/${TRG_DOMAIN}/prompt_annotations

For each source domain, this step creates a file with an annotated prompt for each of its training examples. The file is saved to the following path: './runs/<TASK_NAME>/<TRG_DOMAIN>/prompt_annotations/<SRC_DOMAIN_NAME>/annotated_prompts_train.pt'. The model hyperparameter grid for this step is specified in the paper.
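To verify the annotations, you can load the resulting file with torch. This is a sketch under the assumption that the '.pt' file is a torch-serialized collection of per-example prompts; the exact structure may differ:

# Hedged sketch: peek at the prompt annotations for one source domain.
import torch

path = "./runs/rumor/charliehebdo/prompt_annotations/ferguson/annotated_prompts_train.pt"
annotated_prompts = torch.load(path)
print(type(annotated_prompts))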

4. Training PADA

Train PADA on both the prompt-generation task and the downstream task (conditioned on the annotated prompts). Then evaluate PADA on data from the target domain, where for each example it first generates a prompt and then performs the downstream task conditioned on its self-generated prompt.
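The mixture_alpha hyperparameter (ALPHA_VAL above) balances the two training tasks. As an illustrative sketch only (not the repository's exact training loop), one way to realize such a mixture is to route each training example to one of the two tasks at random:

# Illustrative sketch of alpha-mixed multi-task training (hypothetical helper, not PADA's actual loop).
import random

ALPHA_VAL = 0.2

def build_training_instance(example_text, annotated_prompt, label):
    if random.random() < ALPHA_VAL:
        # Prompt-generation task: learn to predict the DRF-based prompt from the input.
        return "Domain: " + example_text, annotated_prompt
    # Downstream task: predict the label, conditioned on the annotated prompt.
    return annotated_prompt + ": " + example_text, label

The full training run is launched with: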

CUDA_VISIBLE_DEVICES=${GPU_ID} python ./train.py \
  --dataset_name ${TASK_NAME} \
  --src_domains ${SRC_DOMAINS} \
  --trg_domain ${TRG_DOMAIN} \
  --num_train_epochs ${NUM_EPOCHS} \
  --train_batch_size ${TRAIN_BATCH_SIZE} \
  --eval_batch_size ${EVAL_BATCH_SIZE} \
  --mixture_alpha ${ALPHA_VAL}

The final results are saved in the following path: "./runs/<TASK_NAME>/<TRG_DOMAIN>/PADA/e<NUM_EPOCHS>/b<TRAIN_BATCH_SIZE>/a<ALPHA_VAL>/test_results.txt". For rumor detection and aspect prediction, we report the final binary-F1 score on the target domain, denoted as 'test_binary_f1'.
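If you want to collect this metric programmatically, the small parser below assumes 'test_results.txt' stores one 'name = value' pair per line (an assumption about the file format; adapt it if the layout differs):

# Hedged sketch: grep the reported metric out of the results file.
results_path = "./runs/rumor/charliehebdo/PADA/e5/b32/a0.2/test_results.txt"
with open(results_path) as f:
    for line in f:
        if "test_binary_f1" in line:
            print(line.strip())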

Evaluating Trained PADA Models

You can evaluate checkpoints for all our experiments for a given task (Rumor Detection, 'rumor', or Aspect Prediction, 'absa') by downloading the model files from here, extracting them to a designated directory (CKPT_PATH below), and running the run-rumor-eval-checkpoints.sh script:

bash run-rumor-eval-checkpoints.sh

Evaluating a single trained PADA model checkpoint consists of the following steps:

  1. Create an experimental setup - Choose a single target domain of a given task (e.g., charliehebdo from 'Rumor Detection') and its corresponding source domains (ferguson, germanwings-crash, ottawashooting, sydneysiege).
  2. Download the model files from here and extract them to a designated directory CKPT_PATH.
  3. Run PADA - evaluate the trained PADA model checkpoint on its target domain test set.

Next, we go through these steps using our running example:

  • Task - Rumor Detection.
  • Source domains - ferguson, germanwings-crash, ottawashooting, sydneysiege.
  • Target domain - charliehebdo

We use a specific set of hyperparameters (please refer to our paper for more details).

1. Create an experimental setup

GPU_ID=<ID of GPU>
export PYTHONPATH=<path to repository root>
export TOKENIZERS_PARALLELISM=false

TASK_NAME=rumor
ROOT_DATA_DIR=${TASK_NAME}_data
SRC_DOMAINS=ferguson,germanwings-crash,ottawashooting,sydneysiege
TRG_DOMAIN=charliehebdo

EVAL_BATCH_SIZE=<desired batch size>
CKPT_PATH=<path to model files>

2. Evaluate PADA on target domain data

Evaluate a trained PADA model checkpoint on data from the target domain. For each example, PADA first generates a prompt and then performs the downstream task conditioned on its self-generated prompt.

CUDA_VISIBLE_DEVICES=${GPU_ID} python ./eval.py \
  --dataset_name ${TASK_NAME} \
  --src_domains ${SRC_DOMAINS} \
  --trg_domain ${TRG_DOMAIN} \
  --eval_batch_size ${EVAL_BATCH_SIZE} \
  --ckpt_path ${CKPT_PATH}

The final results are saved in the following path: "./runs/<TASK_NAME>/<TRG_DOMAIN>/PADA/eval-ckpt/test_results.txt". For rumor detection and aspect prediction, we report the final binary-F1 score on the target domain (of the best performing model on the source dev data), denoted as 'test_binary_f1'.

How to Cite PADA

@article{DBLP:journals/corr/abs-2102-12206,
  author    = {Eyal Ben{-}David and
               Nadav Oved and
               Roi Reichart},
  title     = {{PADA:} Example-based Prompt Learning for on-the-fly Adaptation to Unseen Domains},
  journal   = {CoRR},
  volume    = {abs/2102.12206},
  year      = {2021},
  url       = {https://arxiv.org/abs/2102.12206},
  eprinttype = {arXiv},
  eprint    = {2102.12206},
  timestamp = {Tue, 02 Mar 2021 12:11:01 +0100},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2102-12206.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

