m-LoRA: Efficient LLM Model Fine-Tune via Multi-LoRA Optimization

m-LoRA (a.k.a. Multi-LoRA Fine-Tune) is an open-source framework for efficiently fine-tuning Large Language Models (LLMs) with multiple LoRA/QLoRA adapters. Key features of m-LoRA include:

  • Efficient LoRA/QLoRA: Optimizes the fine-tuning process and significantly reduces GPU memory usage by sharing a single frozen base model.

  • Multiple LoRA Adapters: Supports concurrent fine-tuning of multiple LoRA/QLoRA adapters.

Updates

  • Support multiple LLaMA2 fine-tuning
  • Support multiple ChatGLM fine-tuning
  • Support multiple LLaMA fine-tuning
  • Baichuan support is on the way

Models

Model     | Model size
----------|-----------------
ChatGLM   | 6B
ChatGLM2  | 6B/12B
ChatGLM3  | 6B
LLaMA     | 7B/13B/33B/65B
LLaMA-2   | 7B/13B/70B
Baichuan  | 7B/13B
Baichuan2 | 7B/13B

Example: Use our system to fine-tune LLaMA-2 with fewer resources: https://www.kaggle.com/code/rraydata/multi-lora-example/notebook

Overview

m-LoRA is a high-throughput LLM fine-tuning framework based on LoRA and QLoRA, compatible with HuggingFace-Transformers LLaMA Models and ChatGLM Models.

[Figure: the basic principle of LoRA and Multi-LoRA.]
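In symbols, the principle of LoRA and its multi-adapter extension can be written roughly as follows. This is the standard LoRA formulation with the adapter index added; the symbols (W_0, A_i, B_i, alpha_i, r_i, n) are notation chosen here for illustration and are not taken from the m-LoRA code.

```latex
% LoRA: the pre-trained weight W_0 is frozen; only the low-rank factors A and B are trained.
h = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)

% Multi-LoRA: n adapters are trained together and all of them reuse the same frozen W_0.
h_i = W_0 x_i + \frac{\alpha_i}{r_i} B_i A_i x_i,
\qquad i = 1, \dots, n
```

Because W_0 appears in every h_i, it only needs to be held in GPU memory once, no matter how many adapters are being trained.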

[Figure: system overview of m-LoRA.]

m-LoRA requires PyTorch and NVIDIA CUDA-compatible GPUs.

Main Contribution

  • Introduces the Multi-LoRA method, capable of enabling the sharing of pre-trained model weights during the fine-tuning process of large language models;
  • Proposes a task scheduling algorithm to enhance the overall throughput of the task training process and reduce total training latency;
  • Builds upon the above by implementing m-LoRA, a high-throughput large language model fine-tuning framework based on LoRA and QLoRA;
  • Evaluates m-LoRA in experiments against existing systems, confirming that m-LoRA effectively utilizes system computing resources, thereby improving training throughput and reducing training latency compared to current systems.
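The weight-sharing idea in the first point can be sketched in a few lines of plain Python. This is a toy illustration, not m-LoRA's implementation: the names LoraAdapter and multi_lora_forward are hypothetical, the matrices are tiny, and a real system would run on PyTorch tensors and batch the shared base computation across adapters.

```python
from dataclasses import dataclass
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


@dataclass
class LoraAdapter:
    """Hypothetical container for one adapter's trainable low-rank factors."""
    a: Matrix     # shape (rank, d_in)
    b: Matrix     # shape (d_out, rank)
    alpha: float  # LoRA scaling numerator
    rank: int

    def delta(self, x: Vector) -> Vector:
        """Low-rank update (alpha / rank) * B @ A @ x."""
        scale = self.alpha / self.rank
        return [scale * y for y in matvec(self.b, matvec(self.a, x))]


def multi_lora_forward(base_weight: Matrix,
                       adapters: List[LoraAdapter],
                       inputs: List[Vector]) -> List[Vector]:
    """Route input i through the single shared frozen weight plus adapter i."""
    outputs = []
    for adapter, x in zip(adapters, inputs):
        base_out = matvec(base_weight, x)  # same frozen weight for every adapter
        outputs.append([h + d for h, d in zip(base_out, adapter.delta(x))])
    return outputs


# One frozen 2x2 base weight shared by two rank-1 adapters (toy values).
W = [[1.0, 0.0], [0.0, 1.0]]
adapter_a = LoraAdapter(a=[[1.0, 0.0]], b=[[0.0], [1.0]], alpha=2.0, rank=1)
adapter_b = LoraAdapter(a=[[0.0, 1.0]], b=[[1.0], [0.0]], alpha=1.0, rank=1)

outputs = multi_lora_forward(W, [adapter_a, adapter_b], [[1.0, 2.0], [3.0, 4.0]])
print(outputs)  # [[1.0, 4.0], [7.0, 4.0]]
```

Only the a and b factors would receive gradients during fine-tuning; W is read by every adapter but never updated, which is why a single copy is enough.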

Experiment Results

Environment: NVIDIA RTX A6000 with Intel Xeon Silver 4314 on Ubuntu 22.04.3

Baseline: We used the widely adopted Alpaca-LoRA as a foundation. On a single GPU, we ran multiple independent Alpaca-LoRA processes in parallel (marked as Baseline@Alpaca-Parallel) and sequentially (marked as Baseline@Alpaca-Seq), forming two baseline methods for the experiments. We tested this on an A100, and the rest of the results are based on the same GPU configuration.

Training Latency and Throughput

Method                   | Latency | Throughput
-------------------------|---------|----------------
Baseline@Alpaca-Seq      | 10.51h  | 608.41 token/s
Baseline@Alpaca-Parallel | 9.85h   | 649.30 token/s
m-LoRA                   | 9.46h   | 674.58 token/s

We ran four identical fine-tuning jobs with the same dataset and the same hyper-parameters under each of the two baselines and under m-LoRA. For each method, we recorded the completion time of every job and took the time of the slowest job as the Training Latency. As shown in the table above, m-LoRA has lower Training Latency than both baseline methods. Specifically, m-LoRA is 9.99% faster than Baseline@Alpaca-Seq and 3.92% faster than Baseline@Alpaca-Parallel.
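The latency definition and the relative comparison above can be reproduced with a small calculation. This is a sketch only: the function names are hypothetical, the per-job completion times are made-up values for illustration, and the only measured numbers used are the rounded latencies from the table.

```python
from typing import Sequence


def training_latency(completion_hours: Sequence[float]) -> float:
    """Training Latency of a run = completion time of its slowest job."""
    return max(completion_hours)


def relative_reduction(baseline_hours: float, new_hours: float) -> float:
    """Percentage by which new_hours is shorter than baseline_hours."""
    return (baseline_hours - new_hours) / baseline_hours * 100.0


# Hypothetical completion times of four jobs in one run (not measured data).
example_run = [9.10, 9.46, 8.70, 9.30]
print(f"latency of example run: {training_latency(example_run):.2f}h")

# Rounded latencies as reported in the table above.
reported_hours = {
    "Baseline@Alpaca-Seq": 10.51,
    "Baseline@Alpaca-Parallel": 9.85,
    "m-LoRA": 9.46,
}

vs_seq = relative_reduction(reported_hours["Baseline@Alpaca-Seq"],
                            reported_hours["m-LoRA"])
vs_parallel = relative_reduction(reported_hours["Baseline@Alpaca-Parallel"],
                                 reported_hours["m-LoRA"])
print(f"vs Baseline@Alpaca-Seq:      {vs_seq:.2f}%")
print(f"vs Baseline@Alpaca-Parallel: {vs_parallel:.2f}%")
```

From the rounded table values this gives roughly 9.99% and 3.96%. Small differences from the quoted percentages are to be expected if the original figures were computed from unrounded timings.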

Video Memory Usage

We ran several fine-tuning jobs with the same dataset and batch_size = {2, 4, 6, 8} under Baseline@Alpaca-Parallel and m-LoRA.

Baseline@Alpaca-Parallel triggered an out-of-memory (OOM) error after 3 parallel tasks at batch size 8, while m-LoRA can handle twice that number.
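A rough parameter count shows why sharing the frozen base model leaves room for more concurrent tasks. The sketch below is an assumption-laden estimate, not a measurement: it counts weights only (no activations, gradients, optimizer state, or quantization), uses a round 7 billion parameters for the base model, and assumes rank-8 adapters on two 4096x4096 projections in each of 32 layers.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA pair: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


def total_weight_params(base_params: int, adapter_params: int,
                        num_jobs: int, shared_base: bool) -> int:
    """Weights held in memory for num_jobs concurrent fine-tuning jobs."""
    base_copies = 1 if shared_base else num_jobs
    return base_copies * base_params + num_jobs * adapter_params


# Illustrative assumptions (not taken from the m-LoRA experiments).
BASE_PARAMS = 7_000_000_000   # round figure for a 7B model
HIDDEN = 4096                 # hidden size
LAYERS = 32                   # transformer layers
TARGETS_PER_LAYER = 2         # e.g. two attention projections per layer
RANK = 8                      # LoRA rank

adapter_params = LAYERS * TARGETS_PER_LAYER * lora_param_count(HIDDEN, HIDDEN, RANK)

jobs = 4
separate = total_weight_params(BASE_PARAMS, adapter_params, jobs, shared_base=False)
shared = total_weight_params(BASE_PARAMS, adapter_params, jobs, shared_base=True)

print(f"params per adapter:           {adapter_params:,}")
print(f"{jobs} separate processes:         {separate:,}")
print(f"{jobs} adapters on a shared base:  {shared:,}")
```

Under these assumptions each adapter adds only a few million parameters, so the weight footprint of the shared setup stays close to that of a single base model, while independent processes multiply it by the number of jobs. Activation memory still grows with batch size in both cases.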

Batching Strategies

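Method              | Training Latency | Peak Memory Usage | Average GPU Utilization | Training Throughput
--------------------|------------------|-------------------|-------------------------|--------------------
Baseline@Alpaca-Seq | 27.73h           | 10.68GB           | 79.39%                  | 653.35 token/s
m-LoRA@M1           | 36.82h           | 23.82GB           | 96.52%                  | 672.54 token/s
m-LoRA@M2           | 39.14h           | 23.86GB           | 96.41%                  | 671.28 token/s
m-LoRA@M3           | 22.97h           | 23.85GB           | 95.22%                  | 674.41 token/s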

We ran four fine-tuning jobs with different datasets but the same hyper-parameters under Baseline@Alpaca-Seq and m-LoRA.

During the experiments, we collected the following metrics:

  • Training Latency = Job completion time
  • Throughput = Number of tokens passed through the model's forward pass / training latency
  • Memory Usage = Peak video memory usage
  • GPU Utilization = Average GPU utilization
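The throughput definition above can be expressed as a small helper. This is a sketch under assumptions: the function name is hypothetical, and the hours-to-seconds conversion is inferred from the units used in the tables (latency in hours, throughput in token/s).

```python
def throughput_tokens_per_second(forward_tokens: int, latency_hours: float) -> float:
    """Throughput = tokens passed through the forward pass / training latency.

    Latency is given in hours and converted to seconds so that the result
    is in token/s, matching the units used in the tables.
    """
    if latency_hours <= 0:
        raise ValueError("latency_hours must be positive")
    return forward_tokens / (latency_hours * 3600.0)


# Hypothetical job: 23 million forward tokens completed in 9.46 hours.
example = throughput_tokens_per_second(23_000_000, 9.46)
print(f"{example:.2f} token/s")
```

Because latency is the denominator, two runs that process the same number of tokens differ in throughput only through their completion time.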

All metrics are computed per job. M1, M2, and M3 denote three batching strategies of m-LoRA: Optimal-Fit, Trivial, and Fast-Fit, respectively. BASELINE denotes Baseline@Alpaca-Seq.

The Optimal-Fit strategy performs the best across all four metrics, while the other two strategies also outperform the baseline method on every metric other than training latency.

Use Cases:

  • Domain-Specific Fine-Tuning: This involves adapting a single model with various parameters particularly for one domain.
  • Cross-Domain Fine-Tuning: This method leverages the base model to fine-tune multiple models, each intended for a different domain.

Quickstart

First, clone this repository and install the dependencies:

# Clone Repository
git clone https://github.com/TUDB-Labs/multi-lora-fine-tune
cd multi-lora-fine-tune
# Install requirements
pip install -r requirements.txt

The mlora.py script is a starting point for fine-tuning on various datasets. The basic command for fine-tuning a baseline model on the Alpaca Cleaned dataset is:

python mlora.py \
  --base_model yahma/llama-7b-hf \
  --config ./config/alpaca.json \
  --load_8bit

You can find template fine-tuning configurations in the template folder.

For more detailed usage information, use the --help option:

python mlora.py --help

Demo on Colab

You can run fine-tuning on Colab by following the Google Colab Example. Make sure to switch the runtime environment to GPU before running it.

Installation

You can also install m-LoRA into your environment:

# Optional but recommended
conda create -n mlora_env python=3.8
conda activate mlora_env
# Install m-LoRA
pip install mlora

After installation, you can use m-LoRA directly in your code:

import mlora

Contributing

We welcome contributions to improve this repository! Please review the contribution guidelines before submitting pull requests or issues.

  • Fork the repository.
  • Create a new branch for your feature or fix.
  • Submit a pull request with a detailed explanation of your changes.

Citation

Please cite this repository if you use its code.

@misc{m-LoRA,
  author = {Zhengmao, Ye\textsuperscript{*} and Dengchun, Li\textsuperscript{*} and Jingqi, Tian and Tingfeng, Lan and Yanbo, Liang and Yexi, Jiang and Jie, Zuo and Hui, Lu and Lei, Duan and Mingjie, Tang},
  title = {m-LoRA: Efficient LLM Model Fine-tune and Inference via Multi-Lora Optimization},
  year = {2023},
  publisher = {GitHub},
  howpublished = {\url{https://github.com/TUDB-Labs/multi-lora-fine-tune}},
  note={\textsuperscript{*}: these authors contributed equally to this work.}
}

Copyright

Copyright © 2023 All Rights Reserved.

This project is licensed under the Apache 2.0 License.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
