VCAS: variance-controlled adaptive sampling

Official implementation of the paper "Efficient Backpropagation with Variance Controlled Adaptive Sampling", accepted at ICLR 2024.

Abstract

VCAS is an efficient, unbiased sampling method for accelerating backpropagation (BP). It reduces the FLOPs of BP by up to 73.87% without loss of performance, with theoretical guarantees.

Figure: example results.

Figure: VCAS algorithm diagram.
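The 73.87% figure refers to BP only. To get a rough sense of what it means for a whole training step, the sketch below assumes the common rule of thumb that the backward pass costs about twice the forward pass (an assumption, not a number from this repository), so BP is about 2/3 of training FLOPs and only that part is reduced. The function name is hypothetical.

```python
def total_flops_reduction(bp_reduction, backward_to_forward_ratio=2.0):
    """Estimate the fraction of total training FLOPs saved.

    Assumes (rule of thumb, not measured here) that a training step costs
    forward + backward, with backward = backward_to_forward_ratio * forward,
    and that sampling reduces only the backward FLOPs.
    """
    bp_share = backward_to_forward_ratio / (1.0 + backward_to_forward_ratio)
    return bp_reduction * bp_share


# An up-to-73.87% BP reduction is roughly half of all training FLOPs
print(f"{total_flops_reduction(0.7387):.2%}")  # prints 49.25%
```

The real share of BP varies with model and implementation, so treat this as an order-of-magnitude estimate.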

Installation

Requirements: Python >= 3.7, CUDA >= 11.0, torch >= 1.12.0, transformers >= 4.21.0

To install VCAS only, run the following:

pip install -v -e .

To also run the examples, set up an environment as follows:

conda create -n vcas python=3.9
conda activate vcas
pip install -r requirements.txt # install dependencies
pip install -v -e . # install vcas
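To confirm that an environment meets the Requirements line above, a small check script can help. This is a hypothetical helper, not part of the repository: it reads installed package versions with importlib.metadata (which itself needs Python 3.8+) and compares the CUDA version torch was built with (torch.version.cuda, None for CPU-only builds) against 11.0.

```python
"""Hypothetical helper: check installed versions against the README requirements."""
from importlib import metadata  # standard library from Python 3.8

MIN_VERSIONS = {"torch": (1, 12, 0), "transformers": (4, 21, 0)}
MIN_CUDA = (11, 0)


def parse_version(text, width=3):
    """Turn '1.12.0+cu113' or '4.21.0.dev0' into a comparable tuple such as (1, 12, 0)."""
    parts = []
    for piece in text.split("+")[0].split(".")[:width]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (width - len(parts))  # pad short versions like '2.0'
    return tuple(parts)


def check_requirements():
    """Return a list of human-readable problems (empty if everything looks fine)."""
    problems = []
    for pkg, minimum in MIN_VERSIONS.items():
        try:
            found = parse_version(metadata.version(pkg))
        except metadata.PackageNotFoundError:
            problems.append(f"{pkg} is not installed")
            continue
        if found < minimum:
            problems.append(f"{pkg} {found} is older than {minimum}")
    try:
        import torch
        cuda = torch.version.cuda  # CUDA version torch was built with
        if cuda is None or parse_version(cuda, width=2) < MIN_CUDA:
            problems.append(f"CUDA >= 11.0 required (torch reports {cuda})")
    except ImportError:
        pass  # a missing torch was already reported above
    return problems


if __name__ == "__main__":
    issues = check_requirements()
    print("\n".join(issues) if issues else "All requirements satisfied.")
```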

Quick Start

Using VCAS in your project takes four steps:

  1. Add VcasSampleArguments to the argument parser
  2. Initialize VcasSampleScheme with VcasSampleArguments
  3. Process the original model with VcasModelProcessor
  4. Substitute the original Hugging Face Trainer with VcasTrainer

from transformers import HfArgumentParser
from vcas import VcasSampleArguments, VcasSampleScheme, VcasModelProcessor, VcasTrainer

......

# CHANGES #1: Add VcasSampleArguments to the argument parser 
parser = HfArgumentParser((..., VcasSampleArguments))
..., sample_args = parser.parse_args_into_dataclasses()

# CHANGES #2: Initialize VcasSampleScheme with VcasSampleArguments
sample_scheme = VcasSampleScheme(sample_args)

# CHANGES #3: Process the original model with VcasModelProcessor
model = ...
act_sampling_repetend = ... # basic block of the model, after which VCAS inserts an activation sampler, e.g. BertLayer in BertModel
processor = VcasModelProcessor(model, act_sampling_repetend, sample_scheme)
processor.process()

# CHANGES #4: Substitute the original Hugging Face Trainer with VcasTrainer
trainer = VcasTrainer(..., sample_scheme=sample_scheme)

......
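Step 1 parses VcasSampleArguments with HfArgumentParser, which exposes each dataclass field as a --field_name command-line option, so the sampling hyperparameters can presumably be set from the launch command rather than in code. The sketch below imitates that mechanism with the standard library only. SampleArgsSketch is a hypothetical stand-in whose fields and defaults are copied from the hyperparameter list in this README; it is not the real VcasSampleArguments (see vcas/sample_args.py for the actual fields).

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class SampleArgsSketch:
    """Hypothetical stand-in for VcasSampleArguments (defaults from this README)."""
    act_var_tau: float = 0.025
    w_var_tau: float = 0.025
    cal_var_freq: int = 100


def parse_into_dataclass(cls, argv=None):
    """Mimic HfArgumentParser: one --field_name option per dataclass field."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    namespace = parser.parse_args(argv)
    return cls(**vars(namespace))


args = parse_into_dataclass(SampleArgsSketch, ["--act_var_tau", "0.05", "--cal_var_freq", "50"])
print(args)  # SampleArgsSketch(act_var_tau=0.05, w_var_tau=0.025, cal_var_freq=50)
```

With the real parser, the equivalent would be appending flags such as --act_var_tau 0.05 to the command in examples/run.sh, assuming the field names match the hyperparameter names listed in this README.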

For research purposes, see the sampling layers in vcas/layers and the adaptive sampling strategy in vcas/vcas_trainer.py; you are free to integrate them directly into your own model and pipeline.
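The sampling is unbiased: dropping part of the gradient does not shift its expected value. As a generic illustration of that principle (a sketch, not the code in vcas/layers), the snippet below keeps each gradient row with probability p_i and rescales kept rows by 1/p_i, so the sampled gradient equals the exact one in expectation. The function names and the norm-proportional keep probabilities are assumptions chosen for illustration.

```python
import itertools
import random


def keep_probs_from_norms(rows, budget):
    """Hypothetical rule: keep probability proportional to each row's L2 norm, capped at 1.

    `budget` is the target expected number of kept rows; capping at 1 can make
    the actual expected count smaller. All-zero rows get probability 0, which
    keeps the estimator unbiased because their true contribution is zero.
    """
    norms = [sum(x * x for x in row) ** 0.5 for row in rows]
    total = sum(norms)
    if total == 0:
        return [0.0] * len(rows)
    return [min(1.0, budget * n / total) for n in norms]


def sample_rows(rows, probs, rng):
    """Keep row i with probability probs[i] and scale it by 1/probs[i]; zero the others."""
    out = []
    for row, p in zip(rows, probs):
        if p > 0 and rng.random() < p:
            out.append([x / p for x in row])
        else:
            out.append([0.0] * len(row))
    return out


def expected_sample(rows, probs):
    """Exact expectation of sample_rows, by enumerating all keep/drop patterns."""
    result = [[0.0] * len(row) for row in rows]
    for pattern in itertools.product([True, False], repeat=len(rows)):
        weight = 1.0
        for kept, p in zip(pattern, probs):
            weight *= p if kept else 1.0 - p
        for i, (kept, row, p) in enumerate(zip(pattern, rows, probs)):
            if kept and p > 0:
                for j, x in enumerate(row):
                    result[i][j] += weight * x / p
    return result


grads = [[3.0, 4.0], [0.3, 0.4], [0.0, 0.0]]  # toy per-token gradient rows
probs = keep_probs_from_norms(grads, budget=1.0)
print(sample_rows(grads, probs, random.Random(0)))  # one random, rescaled draw
print(expected_sample(grads, probs))  # equals grads up to rounding: the estimator is unbiased
```

Each individual draw is noisy; the variance it adds is what VCAS's adaptive strategy keeps under control.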

Example Usage

We provide an example of fine-tuning BERT on the SST-2 dataset in examples/run_glue.py. To run it:

cd examples
bash run.sh

We recorded the results of this run and compared them with the exact-training baseline (run with bash run_baseline.sh) on an NVIDIA 2080 Ti. The results are available at https://wandb.ai/thuwzt/vcas.

Compare examples/run_glue.py with examples/run_glue_baseline.py to see the required modifications. To try VCAS on other models and datasets, change model_name_or_path and task_name in examples/run.sh.

Modifying Sampling Hyperparameters

VCAS provides a set of hyperparameters to which performance has been shown to be insensitive. You can still adjust them to fit your specific task; see vcas/sample_args.py for details.

The main hyperparameters are:

  • act_var_tau: The acceptable ratio of the activation sampling variance to the SGD variance (default: 0.025). A higher value means more aggressive sampling.
  • w_var_tau: The acceptable ratio of the weight sampling variance to the SGD variance (default: 0.025). A higher value means more aggressive sampling.
  • cal_var_freq: How often, in training steps, the variance is calculated (default: 100). A higher value brings less overhead but less thorough adaptation. For thorough adaptation, we recommend keeping it no larger than 1/50 of the total training steps, so that the variance is recalculated at least 50 times.
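To make the roles of these hyperparameters concrete, the sketch below pairs a toy controller with the cal_var_freq arithmetic. It is a hypothetical simplification, not the update rule in vcas/vcas_trainer.py: at each variance calculation, the measured ratio of sampling variance to SGD variance is compared with tau, and the keep ratio is nudged up when the ratio is too high and down otherwise. The step size alpha, the floor min_keep, and both function names are assumptions.

```python
def update_keep_ratio(keep_ratio, sample_var, sgd_var, tau, alpha=0.01, min_keep=0.05):
    """Toy controller that steers the sampling/SGD variance ratio toward tau.

    Hypothetical rule for illustration; the real VCAS update may differ.
    """
    ratio = sample_var / sgd_var
    if ratio > tau:
        keep_ratio += alpha  # too much extra variance: sample less aggressively
    else:
        keep_ratio -= alpha  # variance budget left: sample more aggressively
    return min(1.0, max(min_keep, keep_ratio))  # clip to [min_keep, 1]


def num_variance_updates(total_steps, cal_var_freq):
    """How many times the variance is recalculated over a training run."""
    return total_steps // cal_var_freq


print(num_variance_updates(10_000, 100))  # 100
print(update_keep_ratio(0.5, sample_var=0.004, sgd_var=0.1, tau=0.025))  # ratio 0.04 > tau, so the keep ratio rises
```

Under this rule a larger tau is exceeded less often, so the keep ratio drifts downward, which matches "a higher value means more aggressive sampling" above.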

Contributors

thuwzt, daqige
