
Megatron Cookbook: Pre-training, Compressing, Extending, and Distilling Your LLMs

[Chinese] [English]

Large Language Models (LLMs), pre-trained on massive textual data, have demonstrated a remarkably broad range of capabilities for various downstream tasks. However, training these models is challenging due to memory constraints. Therefore, our codebase, built upon the Megatron-LM project, implements techniques for training LLMs effectively. In addition to scripts for training large models, this codebase also explores methods for compressing LLMs, extending their context length, and distilling them using synthesized text generated by other pre-trained large models.

Key features of this codebase include:

  1. Pre-training and Finetuning: We offer scripts for creating datasets specifically for LLM training. Additionally, our codebase includes functionality for converting Llama checkpoints between the Huggingface format and the Megatron-LM format. Metrics such as TFLOPS and token throughput are reported on 8xA100-80GB devices.
  2. Compression: Using structured pruning followed by fine-tuning on a small number of tokens, the codebase supports pruning a pre-trained Transformer model to any desired size while retaining most of its performance.
  3. Context Length Extension: We utilize Position Interpolation (PI) as detailed in this paper to extend the context length of the llama2-7b model from 4096 to 8192 tokens. Additionally, we provide the Perplexity (PPL) test results on the Pile and PG19 datasets. We also include scripts for extending the context length of our own 13b LLM from 4096 to 32768 tokens.
  4. Distillation: We continually train our 13B model on synthesized text generated by the Qwen-72B and Deepseek-67B models. We include scripts for synthetic text generation and continual training, along with the resulting performance metrics.

Contents

  • Setup
  • Training
  • LLM Compression
  • Context length extension
  • Distillation
  • Credits

Setup

Similar to Megatron-LM, we strongly recommend using the latest release of NGC's PyTorch container with DGX nodes. To launch an instance of the PyTorch container, follow the steps below:

  1. Install Docker and nvidia-docker on your GPU machine.
  2. Execute the following Docker commands:
    docker pull nvcr.io/nvidia/pytorch:23.12-py3
    docker run --gpus all --shm-size=128g --net=host -dit --rm --name megatron -v /your_dir:/your_dir -v /root/.ssh:/root/.ssh nvcr.io/nvidia/pytorch:23.12-py3
  3. Install sentencepiece and nltk in your environment.
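For example, inside the running container:

pip install sentencepiece nltk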

Training

To train your own LLM, follow these essential steps:

  1. Data Processing: Convert your textual data into a binary format suitable for training.
  2. Checkpoint Conversion: If you plan to finetune a pre-trained LLM, such as Llama2-7b, you will need to convert weights from the Huggingface format to the Megatron format. Conversely, to simplify the inference process with your trained LLM, convert the weights from the Megatron format back to the Huggingface format.
  3. Pretraining: Pretrain your own model efficiently.

Data processing

The data preprocessing steps align with those outlined in Megatron-LM. Your training data should be formatted as loose JSON, with each line containing a single JSON object representing a text sample. Below is an example script for preparing data for Llama training:

python tools/preprocess_data.py \
    --input /Path/to/my-corpus.jsonl \
    --output-prefix /Path/to/my-corpus \
    --tokenizer-type Llama2Tokenizer \
    --tokenizer-model /Path/to/tokenizer.model \
    --append-eod \
    --workers 16
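Each line of my-corpus.jsonl is a standalone JSON object. By default, preprocess_data.py reads the text field (controlled by --json-keys), so a minimal corpus might look like:

{"text": "First document of the corpus ..."}
{"text": "Second document of the corpus ..."}

The command above should produce my-corpus_text_document.bin and my-corpus_text_document.idx; the prefix my-corpus_text_document is what you later pass to --data-path during training.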

To combine multiple binary-format datasets into a single dataset, execute the following command:

python tools/merge_datasets.py \
    --input /Path/to/datasets/folder \
    --output-prefix /Path/to/merged/dataset 

Checkpoint conversion

Megatron-LM applies pipeline parallelism and tensor parallelism to enable LLM training with limited memory. Sometimes we need to change the pipeline-parallel and tensor-parallel sizes of a checkpoint. Here is an example:

python tools/checkpoint/util.py \
        --model-type GPT \
        --load-dir /Path/to/ckpts/Llama2-7b-tp1 \
        --save-dir /Path/to/ckpts/Llama2-7b-tp4 \
        --target-tensor-parallel-size 4 \
        --target-pipeline-parallel-size 1 \
        --megatron-path /Path/to/Megatron

To convert Huggingface-format weights into Megatron format, here is an example for Llama:

TP=1
HF_FORMAT_DIR=/Path/to/Llama-2-7b-hf
MEGATRON_FORMAT_DIR=/Path/to/Llama2-7b-tp1
TOKENIZER_MODEL=/Path/to/Llama-2-7b-hf/tokenizer.model

python tools/checkpoint/util.py \
--model-type GPT \
--loader llama2_hf \
--saver megatron \
--target-tensor-parallel-size ${TP} \
--load-dir ${HF_FORMAT_DIR} \
--save-dir ${MEGATRON_FORMAT_DIR} \
--tokenizer-model ${TOKENIZER_MODEL}

To convert Megatron-format weights back into Huggingface format, you should first use the script above to convert the Megatron checkpoint to a tensor-parallel and pipeline-parallel size of 1. Here is an example for Llama:

python tools/checkpoint_conversion/llama_checkpoint_conversion.py \
--convert_checkpoint_from_megatron_to_transformers \
--load_path "/Path/to/Llama2-7b-tp1" \
--save_path "/Path/to/Llama2-7b-hf" \
--target_params_dtype "bf16" \
--make_vocab_size_divisible_by 1 \
--print-checkpoint-structure \
--megatron-path /Path/to/Megatron

Examples for 31B/65B/108B/132B training

We provide scripts for training 31B, 65B, 108B and 132B llama-based LLMs. The TFLOPS and token throughput on A100-SXM4-80G GPUs are reported in the table below:

                             31B      65B      108B     132B
TFLOP/s per GPU              161      161      174      177
Tokens/day (8 x A100-80G)    0.59B    0.27B    0.17B    0.15B
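The tuned launch scripts for these runs are provided in the repo. As a rough single-node sketch of what a Megatron-LM launch looks like, here is a fine-tuning command for the converted Llama2-7B checkpoint; the flag values below are illustrative assumptions, not the settings used for the table above:

CHECKPOINT_DIR=/Path/to/ckpts/Llama2-7b-tp1
DATA_PATH=/Path/to/my-corpus_text_document
TOKENIZER_MODEL=/Path/to/tokenizer.model

torchrun --nproc_per_node 8 pretrain_gpt.py \
    --tensor-model-parallel-size 1 \
    --pipeline-model-parallel-size 1 \
    --num-layers 32 \
    --hidden-size 4096 \
    --ffn-hidden-size 11008 \
    --num-attention-heads 32 \
    --seq-length 4096 \
    --max-position-embeddings 4096 \
    --micro-batch-size 1 \
    --global-batch-size 128 \
    --train-iters 10000 \
    --lr 3e-5 \
    --min-lr 3e-6 \
    --lr-decay-style cosine \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --bf16 \
    --swiglu \
    --normalization RMSNorm \
    --disable-bias-linear \
    --untie-embeddings-and-output-weights \
    --use-rotary-position-embeddings \
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --tokenizer-type Llama2Tokenizer \
    --tokenizer-model ${TOKENIZER_MODEL} \
    --data-path ${DATA_PATH} \
    --split 990,9,1 \
    --load ${CHECKPOINT_DIR} \
    --save /Path/to/ckpts/Llama2-7b-finetuned \
    --save-interval 1000 \
    --log-interval 10 \
    --eval-interval 1000 \
    --eval-iters 10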

LLM Compression

Using structured pruning followed by fine-tuning on a small number of tokens, this approach supports pruning a pre-trained Transformer model to any desired size while retaining most of its performance. For a Transformer model, the parameter count is determined by layer_num, hidden_size, intermediate_size, and num_attention_heads. With this code, you only need to set new values for drop_layers, hidden_size_remain, intermediate_size_remain, and num_attention_heads_remain to compress the model to a smaller size, as sketched below.
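Here is a sketch of pruning arguments for the Pruned-3.4B configuration in the table below. The flag spellings follow the parameter names above, but the exact CLI names, the entry script, and the choice of layers to drop are assumptions; see the Model Compression doc for the real settings.

# Pruning arguments (hypothetical spellings) for the Pruned-3.4B setting
PRUNE_ARGS="
    --drop_layers 28,29,30,31
    --hidden_size_remain 3072
    --intermediate_size_remain 8192
    --num_attention_heads_remain 28
"
# Append ${PRUNE_ARGS} to the usual training command, e.g.:
# torchrun --nproc_per_node 8 pretrain_gpt.py <usual arguments> ${PRUNE_ARGS}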

We used this codebase to compress llama2-13B to 7B and llama2-7B to 3.4B, followed by fine-tuning on 20B and 12B tokens, respectively. The compression results are as follows:

Parameter Settings:

Model          layer_num   hidden_size   intermediate_size   num_attention_heads   LM loss
LLaMA2-13B     40          5120          13824               40                    1.50
LLaMA2-7B      32          4096          11006               32                    1.54
Pruned-7B      32          4096          11006               32                    1.56 (20B tokens)
Pruned-3.4B    28          3072          8192                28                    1.71 (12B tokens)

Recovery Curve

[Figure: recovery curve of the LM loss during fine-tuning after pruning]

For detailed usage instructions and a comparison of the effects after pruning, please refer to Model Compression.

Context length extension

LLMs typically have fixed context lengths, such as 2048 for LLaMA models and 4096 for LLaMA2 models. However, these preset context lengths may not suffice for many downstream tasks that require longer context windows, such as long conversations or extracting information from long documents. Consequently, extending the context window of pretrained LLMs becomes essential. In this section, we provide tutorials on extending the context window from 4096 to up to 32768 tokens using Position Interpolation. Remarkably, with just 1000 steps of continual training, Position Interpolation can achieve high-quality performance in long text modeling.

Examples for context length extension training

We experimented with two settings to extend the context length of llama-based models. The script for extending the 7B model from 4096 to 8192 tokens can be found here, and the script for extending the 13B model from 4096 to 32768 tokens can be found here. The continual training is conducted on the Pile dataset. Below is a figure illustrating the loss comparison for the 7B model at 8192 context length, with and without Position Interpolation.

[Figure: training loss of the 7B model at 8192 context length, with vs. without Position Interpolation]
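In terms of arguments, the Position Interpolation setting mainly amounts to raising the sequence length while scaling down the rotary frequencies. A sketch for the 7B 4096-to-8192 run (illustrative; see the linked script for the full argument list):

# Added to the usual continual-training command
EXTEND_ARGS="
    --seq-length 8192
    --max-position-embeddings 8192
    --rotary-seq-len-interpolation-factor 2
"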

We conducted tests to compare the perplexity (PPL) of the original Llama2-7b and its fine-tuned variants, with and without Position Interpolation (PI), on the Pile and PG-19 datasets. The results are presented below:

Pile evaluation dataset:

                        PPL @ 4096   PPL @ 8192
Original Llama2-7b      6.137        -
w/ PI + fine-tune       5.978        5.859
w/o PI + fine-tune      6.066        -

PG-19:

                        PPL @ 4096   PPL @ 8192
Original Llama2-7b      5.956        -
w/ PI + fine-tune       5.861        5.702
w/o PI + fine-tune      5.943        -

The results indicate that with position interpolation, just 1000 steps can achieve high quality in long-context language modeling.

When converting the Megatron-format checkpoint to the Huggingface format, remember to adjust the frequencies used in the rotary embedding. For instance, when training with --rotary-seq-len-interpolation-factor 2, it is necessary to modify the corresponding function in modeling_llama.py:

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Original: t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        # Divide t by the rotary-seq-len-interpolation-factor used in training (2 in this example)
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) / 2.0

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from the paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

Distillation

Distillation is an effective method for transferring the knowledge and capabilities acquired from the pre-training of a large language model to a smaller, custom model. In our approach, we utilize synthetic data generated by sufficiently pre-trained LLMs, such as Deepseek-67B and Qwen-72B, to train our own 13B model. This strategy has resulted in a performance improvement on the FinEval financial test dataset.

Generating synthetic data

To generate synthetic data efficiently, we use vLLM for inference, taking prefixes of pretraining data as prompts. The scripts used in this process are available here.
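The repo's scripts are linked above; as a minimal, self-contained illustration of the idea using vLLM's OpenAI-compatible server (the model choice, sampling parameters, and prompt are placeholders, not the settings used in our runs):

# Serve the teacher model with vLLM (8-way tensor parallelism as an example)
python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen-72B \
    --tensor-parallel-size 8 \
    --trust-remote-code

# Send a prefix of a pretraining sample as the prompt and keep the completion
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "Qwen/Qwen-72B", "prompt": "<prefix of a pretraining sample>", "max_tokens": 1024, "temperature": 0.8}'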

Results of distillation

We evaluated our distilled model on the FinEval dataset, and the results are reported in the table below:

Model                   Accounting   Certificate   Economy   Finance   Average
base model              43.60        44.61         40.09     49.83     44.91
distill on 3B tokens    40.00        49.70         42.02     47.54     45.17
distill on 10B tokens   43.27        47.90         38.64     51.14     45.87

Credits

The following repositories are used in Megatron-Cookbook, either in close to original form or as an inspiration:

Megatron-LM

Megatron-LLaMA

Pruning-LLMs
