Chain of Hindsight in PyTorch & Huggingface Trainer

This is an unofficial implementation of Chain of Hindsight using PyTorch and the Huggingface Trainer. The data loading script is taken directly from the original repo; only the training part is rewritten in PyTorch.

Installation

  • For pip,
pip install -r environment/requirements.txt
  • For conda,
conda env create -f environment/env.yml

Train

A shell script for training can be found in train.sh. It takes a comma-separated list of GPU device ids as its argument and passes it to the CUDA_VISIBLE_DEVICES environment variable.

sh train.sh 0,1,2,3

To customize command line arguments, take a look at the arguments dataclasses used in the following files:

  • coh.coh_train.ExperimentArgs
  • coh.data.coh_data.CoHDataArgs
  • coh.trainer.CoHTrainArgs (this inherits from transformers.TrainingArguments)
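One quick way to see which options such a dataclass exposes is to list its fields. The sketch below uses a hypothetical stand-in, DemoArgs, because the real fields of ExperimentArgs, CoHDataArgs, and CoHTrainArgs are not reproduced here. It also assumes the scripts parse these dataclasses with transformers.HfArgumentParser, which maps each field to a --field_name flag.

```python
from dataclasses import MISSING, dataclass, fields


@dataclass
class DemoArgs:
    """Hypothetical stand-in; these are NOT the repo's actual arguments."""
    model_name: str = "gpt2"
    seq_length: int = 512
    use_lora: bool = False


def describe_args(cls):
    """Return (flag, type name, default) for every field of a dataclass."""
    rows = []
    for f in fields(cls):
        if f.default is not MISSING:
            default = f.default
        elif f.default_factory is not MISSING:
            default = f.default_factory()
        else:
            default = None  # no default: the field is required
        # f.type is a string when annotations are postponed, a class otherwise.
        if isinstance(f.type, str):
            type_name = f.type
        else:
            type_name = getattr(f.type, "__name__", str(f.type))
        rows.append((f"--{f.name}", type_name, default))
    return rows


if __name__ == "__main__":
    for flag, type_name, default in describe_args(DemoArgs):
        print(f"{flag:<15} {type_name:<6} default={default!r}")
```

The same helper should work on the real classes once they are imported, for example describe_args(CoHDataArgs).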

Train LLaMA

A training script for LLaMA is also provided. The baseline command is:

sh llama_train.sh 0,1,2,3 ${LLAMA_PATH}

To use this script, you need to have already downloaded the LLaMA weights and converted them to PyTorch weights using the conversion script in the Huggingface transformers repo.

  • Relevant PR
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights \
    --model_size 7B \
    --output_dir /output/path

DeepSpeed

To use DeepSpeed, you need the correct version of nvcc installed. Conda provides the cuda-nvcc package, which is included in env.yml. To use it, set the CUDA_HOME environment variable to point to the conda environment; this is required so that DeepSpeed's JIT C++ compiler picks up the conda-installed nvcc rather than the system-wide one. After creating and activating the environment, set:

export CUDA_HOME=/path/to/conda/envs/coh
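Before launching a run, it can help to confirm that CUDA_HOME actually resolves to an nvcc binary. The following is a minimal sketch, not part of the repo: find_nvcc is a hypothetical helper, and it assumes nvcc sits at bin/nvcc under CUDA_HOME, which is where the conda cuda-nvcc package usually places it.

```python
import os
from pathlib import Path


def find_nvcc(env=None):
    """Return the nvcc path under CUDA_HOME, or None if it cannot be found."""
    env = os.environ if env is None else env
    cuda_home = env.get("CUDA_HOME")
    if not cuda_home:
        return None
    # Assumption: the compiler lives at $CUDA_HOME/bin/nvcc.
    nvcc = Path(cuda_home) / "bin" / "nvcc"
    return nvcc if nvcc.is_file() else None


if __name__ == "__main__":
    nvcc = find_nvcc()
    if nvcc is None:
        print("CUDA_HOME is unset or has no bin/nvcc")
    else:
        print(f"nvcc found at {nvcc}")
```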

Example deepspeed config files can be found in ds_config. They are directly taken from huggingface's deepspeed integration tutorial.

By default, train.sh uses deepspeed. llama_train.sh uses FSDP instead.

PEFT

To make training more efficient, LoRA from PEFT can be applied. Pass --use_lora in the training arguments.

8-bit training

You can also use 8-bit training!

  • This is compatible with PEFT.
  • This is NOT compatible with DeepSpeed.
    • Use the torchrun launcher instead of the deepspeed launcher.
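The compatibility rules above can be written down as a small guard. This is only an illustrative sketch: pick_launcher and its parameters are hypothetical and do not exist in the repo, whose scripts choose the launcher themselves.

```python
def pick_launcher(use_8bit: bool, use_deepspeed: bool) -> str:
    """Return the launcher name for a run, rejecting the unsupported combination."""
    if use_8bit and use_deepspeed:
        # 8-bit training is not compatible with DeepSpeed.
        raise ValueError("8-bit training cannot be combined with DeepSpeed; use torchrun")
    return "deepspeed" if use_deepspeed else "torchrun"


if __name__ == "__main__":
    print(pick_launcher(use_8bit=True, use_deepspeed=False))
```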

Notice

This repo diverges from the original repo's implementation in a few ways:

  1. The original repo does not have an evaluation step.
  2. Here, no bos_token is prepended to the input_ids. Because the batching logic is chunk-wise, each sequence in a batch is not really a sentence.
  3. No weight_decay_mask is used.
  4. Forgetful Causal Masking is not applied.
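To illustrate the second point, here is a minimal sketch of chunk-wise batching. It assumes "chunk-wise" means that tokenized examples are concatenated and cut into fixed-size chunks; the actual data loading script may differ, for instance in how it handles the leftover tail. chunk_token_ids is a hypothetical helper, not a function from the repo.

```python
def chunk_token_ids(sequences, chunk_size):
    """Concatenate tokenized sequences and cut them into fixed-size chunks."""
    flat = [tok for seq in sequences for tok in seq]
    # Assumption: the incomplete tail is dropped rather than padded.
    usable = len(flat) - len(flat) % chunk_size
    return [flat[i:i + chunk_size] for i in range(0, usable, chunk_size)]


if __name__ == "__main__":
    # Three made-up "sentences" of token ids.
    sentences = [[11, 12, 13], [21, 22, 23, 24, 25], [31, 32]]
    for chunk in chunk_token_ids(sentences, chunk_size=4):
        print(chunk)
```

With these inputs the first chunk ends with the first token of the second sentence, and the second chunk begins in the middle of that sentence, so a bos_token at the start of each chunk would not mark a real sentence start.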
