
Text-AutoAugment (TAA)

This repository contains the code for our paper Text AutoAugment: Learning Compositional Augmentation Policy for Text Classification (EMNLP 2021 main conference).

[Figure: overview of the Text-AutoAugment (TAA) framework]

Updates

  • [21.10.27]: We made taa installable as a package and adapted it to huggingface/transformers. Now you can search for an augmentation policy on your custom dataset with TWO lines of code.

Overview

  1. We present a learnable and compositional framework for data augmentation. Our proposed algorithm automatically searches for the optimal compositional policy, which improves the diversity and quality of augmented samples.

  2. In low-resource and class-imbalanced regimes of six benchmark datasets, TAA significantly improves the generalization ability of deep neural networks like BERT and effectively boosts text classification performance.

Getting Started

Prepare environment

Install PyTorch and a few other small dependencies, then install this repo as a Python package. Note that cudatoolkit=10.0 below should match the CUDA version on your machine.

# Clone this repo
git clone https://github.com/lancopku/text-autoaugment.git
cd text-autoaugment

# Create a conda environment
conda create -n taa python=3.6
conda activate taa

# Install dependencies
conda install pytorch cudatoolkit=10.0 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
pip install git+https://github.com/wbaek/theconf
pip install git+https://github.com/ildoonet/pystopwatch2.git
pip install -r requirements.txt

# Install this library in development mode (no need to re-install after modifying the source code)
python setup.py develop

# Download the models in NLTK
python -c "import nltk; nltk.download('wordnet'); nltk.download('averaged_perceptron_tagger')"

Use TAA with Huggingface

Get augmented training dataset with TAA policy

Search for the optimal policy

You can search for the optimal policy on classification datasets supported by huggingface/datasets:

from taa.search_and_augment import search_and_augment

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = search_and_augment(configfile="/path/to/your/config.yaml")

The configfile contains some preset arguments, including:

  • model:

    • type: backbone model
  • dataset:

    • path: Path or name of the dataset
    • name: Name of the dataset configuration
    • data_dir: Data directory of the dataset configuration
    • data_files: Path(s) to source data file(s)

    All the arguments above are passed to the load_dataset() function in huggingface/datasets; please refer to its documentation for details (a rough mapping is sketched after this list).

    • text_key: Used to get the text from a data instance (dict form in huggingface/datasets; see this IMDB example)
  • abspath: Your working directory

  • aug: Pre-searched policy. Now we support IMDB, SST5, TREC, YELP2 and YELP5. See archive.py.

  • per_device_train_batch_size: Batch size per device for training

  • per_device_eval_batch_size: Batch size per device for evaluation

  • epoch: Number of training epochs

  • lr: Learning rate

  • max_seq_length: Maximum input sequence length

  • n_aug: Augment each text sample n_aug times

  • num_op: Number of operations per sub-policy

  • num_policy: Number of sub-policies per policy

  • method: Search method (taa)

  • topN: Ensemble the top-N sub-policies to obtain the final policy

  • ir: Imbalance rate

  • seed: Random seed

  • trail: Trial index under the current random seed

  • train:

    • npc: Number of examples per class in the training dataset
  • valid:

    • npc: Number of examples per class in the val dataset
  • test:

    • npc: Number of examples per class in the test dataset
  • num_search: Number of optimization iterations

  • num_gpus: Number of GPUs used in RAY

  • num_cpus: Number of CPUs used in RAY
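
As a rough illustration of the dataset-related arguments (a hedged sketch, not TAA's internal code; the values assume the public IMDB dataset on the huggingface hub):

# Sketch only: shows how the dataset.* config arguments map onto
# huggingface/datasets. Adjust the values to match your own config.yaml.
from datasets import load_dataset

dataset = load_dataset(
    path="imdb",      # dataset.path
    name=None,        # dataset.name (configuration name, if any)
    data_dir=None,    # dataset.data_dir
    data_files=None,  # dataset.data_files
)

# text_key selects the text field of a data instance (a dict):
text_key = "text"
first_text = dataset["train"][0][text_key]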

ATTENTION: bert_sst2_example.yaml is an example config file for the BERT model and the SST2 dataset. You can follow it to create your own config file. For instance, if you only want to change the dataset from sst2 to imdb, just delete the sst2 in the 'path' argument, change 'name' to imdb, and change 'text_key' to text.

WARNING: The policy optimization framework is based on ray. By default we use 4 GPUs and 40 CPUs for policy optimization; make sure your computing resources meet this requirement, or create a new config file with smaller values. Please also specify the visible GPUs, e.g., CUDA_VISIBLE_DEVICES=0,1,2,3, before running the code above. TPUs do not appear to be supported at this time.
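
For example, a minimal sketch of pinning the visible GPUs from Python rather than the shell (equivalent as long as it runs before CUDA is initialized):

import os

# Make only GPUs 0-3 visible to TAA/ray; equivalent to launching with
# CUDA_VISIBLE_DEVICES=0,1,2,3 on the command line.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

from taa.search_and_augment import search_and_augment
augmented_train_dataset = search_and_augment(configfile="/path/to/your/config.yaml")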

Use our pre-searched policy

To train a model on a dataset augmented by our pre-searched policy, use the following (taking IMDB as an example):

from taa.search_and_augment import augment_with_presearched_policy

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = augment_with_presearched_policy(configfile="/path/to/your/config.yaml")

Now we support IMDB, SST5, TREC, YELP2 and YELP5. See archive.py for details.
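
Conceptually, each policy is a set of sub-policies, and each sub-policy chains num_op operations, each carrying a sampling probability and a magnitude. A purely illustrative sketch (the operation names and values below are hypothetical; see archive.py for the real entries):

# Illustrative structure only; operation names and numbers are hypothetical.
# The actual pre-searched policies live in archive.py.
example_policy = [
    # each sub-policy = num_op (operation, probability, magnitude) triples
    [("tfidf_word_substitute", 0.6, 0.4), ("random_word_delete", 0.3, 0.2)],
    [("tfidf_word_insert", 0.5, 0.3), ("random_word_swap", 0.4, 0.1)],
]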

This table lists the test accuracy (%) of the pre-searched TAA policies on the full datasets:

Dataset   IMDB    SST-5   TREC    YELP-2   YELP-5
No Aug    88.77   52.29   96.40   95.85    65.55
TAA       89.37   52.55   97.07   96.04    65.73
n_aug     4       4       4       2        2

More pre-searched policies and their performance are coming soon.

Fine-tune a new model on the augmented training dataset

After getting augmented_train_dataset, you can pass it directly to the huggingface Trainer. Please refer to search_augment_train.py for details; a minimal sketch follows.
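
A minimal fine-tuning sketch (assuming a standard huggingface/transformers setup; the model name and hyperparameters are placeholders, and search_augment_train.py remains the authoritative version):

# Sketch only: fine-tune a classifier on the TAA-augmented dataset.
# `augmented_train_dataset` comes from search_and_augment() or
# augment_with_presearched_policy() above.
from transformers import (AutoModelForSequenceClassification, Trainer,
                          TrainingArguments)

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2)  # num_labels=2 assumes binary IMDB

training_args = TrainingArguments(
    output_dir="./results",           # placeholder output directory
    num_train_epochs=3,               # placeholder hyperparameters
    per_device_train_batch_size=8,
)

trainer = Trainer(model=model, args=training_args,
                  train_dataset=augmented_train_dataset)
trainer.train()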

Reproduce results in the paper

Please see examples/reproduce_experiment.py, and run script/huggingface_lowresource.sh or script/huggingface_imbalanced.sh.

Contact

If you have any questions related to the code or the paper, feel free to open an issue or email Shuhuai (renshuhuai007 [AT] gmail [DOT] com).

Acknowledgments

Code refers to: fast-autoaugment.

Citation

If you find this code useful for your research, please consider citing:

@inproceedings{ren2021taa,
  title={Text AutoAugment: Learning Compositional Augmentation Policy for Text Classification},
  author={Shuhuai Ren and Jinchao Zhang and Lei Li and Xu Sun and Jie Zhou},
  booktitle={EMNLP},
  year={2021}
}

License

MIT
