
Comments (4)

jieguangzhou commented on June 6, 2024

Background

In SuperDuperDB, the scenarios for LLM training could be:

  1. A vast amount of document information for continuous pre-training.
  2. Data with specific purposes for Supervised Fine Tuning (SFT).
  3. Other scenarios to be determined.

The advantages of training on SuperDuperDB for the above scenarios include:

  1. Data Security: Sensitive data remains within the database.
  2. Data Utilization: Training data can be managed through the database, replacing traditional file formats or other cloud storage (like HuggingFace datasets). Due to the textual nature of LLM training data, MongoDB offers significant advantages.
  3. Data Loading: Data need not be exported but is loaded through a direct database connection using the specialized Dataset mechanism of SuperDuperDB.
  4. Trainer reuse: Training configurations are recorded in the Metadata Store and can be reloaded using db.load("trainer", 'llm_trainer_xxx') for reproducibility or model iteration.

For LoRA-based fine-tuning, the most frequent task is model training on single or multiple GPUs. Different training needs call for different technologies:

  1. Scalable training: ray
  2. Distributed training when GPU memory is insufficient: deepspeed, fsdp
  3. Remote training task submission: ray

Use cases

Local Training with Simple Configuration

from superduperdb.ext.llm.train import LLMTrainer
from superduperdb import superduper
from superduperdb.backends.mongodb import Collection  # import path may differ across superduperdb versions

db = superduper(mongodb_uri)

prompt_template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
    "\n\n### Instruction:\n{X}\n\n### Response:\n{Y}"
)

trainer = LLMTrainer(model="facebook/opt-125m", bits=4, qlora=True, prompt_template=prompt_template)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())

Local Training with More Granular Configuration

from superduperdb.ext.llm.train import LLMTrainer, LoraArguments, TrainingArguments, DataArguments, ModelArguments
from superduperdb import superduper

db = superduper(mongodb_uri)

model_args = ModelArguments(model="facebook/opt-125m", ....)
lora_args = LoraArguments(bits=4, qlora=True, ...)
training_args = TrainingArguments(output='./outputs', learning_rate=1e-5, ...)
data_args = DataArguments(cache_size=10000, max_length=1024, prompt_template=prompt_template)

trainer = LLMTrainer(model_args, lora_args, training_args, data_args)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())

Deepspeed Configuration (an FSDP configuration would look similar, but we only need to support one of the two for now; they serve the same purpose and Deepspeed is reportedly the stronger option)

from superduperdb.ext.llm.train import LLMTrainer, DeepspeedConfig
from superduperdb import superduper

db = superduper(mongodb_uri)

deepspeed_config = DeepspeedConfig(....)

trainer = LLMTrainer(model="facebook/opt-125m", deepspeed_config=deepspeed_config)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())

Ray Configuration

from superduperdb.ext.llm.train import LLMTrainer, RayConfig
from superduperdb import superduper

db = superduper(mongodb_uri)

ray_config = RayConfig(....)

trainer = LLMTrainer(model="facebook/opt-125m", ray_config=ray_config)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())

Custom Training (not supported yet)

Implementation Methods

Lora

Implement QLoRA using the PEFT package; there are no special considerations here.
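
A minimal sketch of how the QLoRA setup could look with PEFT and bitsandbytes; the base model, ranks, and other hyperparameters are illustrative placeholders rather than the final LLMTrainer internals:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit quantized base model (QLoRA)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=bnb_config)
model = prepare_model_for_kbit_training(model)

# LoRA adapters on top of the frozen, quantized base model
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()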

Data Saving

  1. Save checkpoints locally.
  2. Save checkpoints to the ArtifactStore (lower priority; to be added after the other pieces are complete).
  3. Save checkpoints to S3 (used for remote training with Ray; Ray supports this natively, so it should be reusable). A sketch of the S3 upload follows this list.
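
For the S3 case, a hedged sketch of pushing a local checkpoint directory to a bucket with boto3; the bucket name, prefix, and checkpoint path are placeholders:

import os
import boto3

def upload_checkpoint(local_dir: str, bucket: str, prefix: str) -> None:
    """Upload every file in a local checkpoint directory to S3."""
    s3 = boto3.client("s3")
    for root, _, files in os.walk(local_dir):
        for name in files:
            path = os.path.join(root, name)
            key = f"{prefix}/{os.path.relpath(path, local_dir)}"
            s3.upload_file(path, bucket, key)

upload_checkpoint("./outputs/checkpoint-500", "my-training-bucket", "llm_trainer_xxx/checkpoint-500")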

Weight Merging

vLLM does not yet support loading LoRA weights directly. To deploy the model, we need to merge the base model and the LoRA weights, which can be done with a small merge script (sketched after the list below). This step can be dropped once vLLM supports LoRA adapters.

Current LoRA model inference plan:

  1. Use native transformers support, which might perform slightly worse.
  2. Merge the weights and load the merged model with vLLM.

We can switch to vLLM's own solution once it supports Multi-LoRA.
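
A sketch of what the merge script plus the subsequent vLLM load could look like; the adapter and output paths are placeholders, and merge_and_unload is PEFT's standard way of folding LoRA weights back into the base model:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

# 1) Merge the LoRA adapter into the base model and save a plain HF checkpoint.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype="auto")
merged = PeftModel.from_pretrained(base, "./outputs/lora_adapter").merge_and_unload()
merged.save_pretrained("./outputs/merged_model")
AutoTokenizer.from_pretrained("facebook/opt-125m").save_pretrained("./outputs/merged_model")

# 2) The merged checkpoint is a normal HF model, so vLLM can serve it directly.
llm = LLM(model="./outputs/merged_model")
outputs = llm.generate(
    ["### Instruction:\nSay hello\n\n### Response:\n"],
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)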

Datasets

Currently, superduperdb has two relevant datasets:

  1. QueryDataset: For loading small datasets, read in full.
  2. CachedQueryDataset: For loading large datasets, read in batches.

Datasets Handling in Multiprocessing

Distributed training frameworks such as Deepspeed launch multiple processes, so the dataset may need special handling, for example using the local rank to decide which rows each process fetches. This will need to be adapted in practice; one possible approach is sketched below.
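
A sketch of that approach, assuming each process reads the rank set by torchrun/Deepspeed from the environment; RankShardedDataset and fetch_documents are hypothetical names, not superduperdb's actual dataset classes:

import os
from torch.utils.data import IterableDataset

class RankShardedDataset(IterableDataset):
    """Yield only the documents belonging to this process's shard."""

    def __init__(self, fetch_documents):
        # fetch_documents: a callable that streams documents from the database (hypothetical)
        self.fetch_documents = fetch_documents
        # On a single node, RANK equals LOCAL_RANK.
        self.rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))

    def __iter__(self):
        for i, doc in enumerate(self.fetch_documents()):
            # Round-robin split: process k keeps every world_size-th document starting at k.
            if i % self.world_size == self.rank:
                yield doc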

Datasets Handling in Ray Remote Training

Since a Ray remote worker may not be able to reach the database, we assume Ray compute clusters cannot connect to it directly. The query results should therefore be converted into a Ray Dataset before being shipped to the cluster, as sketched below.
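
Continuing from the earlier db/Collection setup, a hedged sketch of exporting the query results client-side and handing them to the cluster as a Ray Dataset; the assumption that db.execute runs where the database is reachable is ours, and the field names are placeholders:

import ray

# Export the training rows while a database connection is still available,
# then ship them to the Ray cluster as a Ray Dataset.
rows = [
    {"instruction": r["instruction"], "output": r["output"]}
    for r in db.execute(Collection("docs").find())
]

ray.init()  # or ray.init(address="ray://<head-node>:10001") for a remote cluster
train_ds = ray.data.from_items(rows)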

Launcher

For distributed training, a dedicated launcher such as deepspeed or accelerate is generally required. The scenarios are:

  1. Single-GPU training or Jupyter notebooks: plain Python can be used as the launcher.
  2. Multi-GPU training: deepspeed or accelerate is required as the launcher.
  3. Ray (or Ray + Deepspeed) distributed training: plain Python can be used directly as the launcher.

The special case is therefore non-Ray, multi-GPU distributed training. There are two options:

  1. Run distributed training without configuring Ray, using Deepspeed as the launcher (other launchers remain possible).
  2. Default to Ray for distributed training, initializing a local Ray cluster or connecting to a remote cluster with ray.init().

The current suggestion is to proceed with option 1, since it only requires changing the launcher at startup. Once Ray distributed training has been adapted and the major bugs resolved, we can switch directly to option 2. The switching cost is low because Ray support has to be built anyway and should require only minimal code changes later. A rough launcher-selection sketch follows.
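
A rough sketch of how the launcher could be chosen at startup under option 1; the entry-point script and flags are placeholders, and only the standard deepspeed CLI is assumed:

import subprocess
import sys

def launch_training(script: str, num_gpus: int, use_ray: bool) -> None:
    """Run the training script with plain Python, or re-launch it under deepspeed for multi-GPU."""
    if use_ray or num_gpus <= 1:
        # Ray runs (local or remote) and single-GPU runs use plain Python as the launcher.
        cmd = [sys.executable, script]
    else:
        # Non-Ray multi-GPU runs are re-launched under the deepspeed launcher.
        cmd = ["deepspeed", f"--num_gpus={num_gpus}", script]
    subprocess.run(cmd, check=True)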

Experimental Records

As a Component, the Trainer can be recorded in the metadata_store for experiment reproducibility and model iteration. Either of the following operations records it:

trainer.fit(db=db, ....)
# or
db.add(trainer)

For experiment reproducibility/model iteration:

trainer = db.load("trainer", 'xxxx')
trainer.fit(....)

All configuration information will be recorded in the metadata.

Training Framework/Code

The closest reusable framework is https://github.com/huggingface/trl, a training framework that encapsulates many LLM-specific details and provides RLHF training. We are not using it for now, so we will implement the related code ourselves.

Demo Code
Deepspeed: https://github.com/SuperDuperDB/llama2-experiments/blob/main/new_experiments/train_new.py

deepspeed + Ray: https://github.com/SuperDuperDB/llama2-experiments/blob/main/new_experiments/gptj_deepspeed_fine_tuning.ipynb

Discussion

Whether to make this a separate repo or implement it directly in the superduperdb repo.

For example

pip install superduperdb-llm-trainer


jieguangzhou commented on June 6, 2024

...


blythed commented on June 6, 2024

Thanks @jieguangzhou. Would it make sense to go directly to support for ray or what do you think?

In your tasks-list you write "Ray and deepspeed to run local multi GPU training". Couldn't this also be remote?


jieguangzhou commented on June 6, 2024

> Thanks @jieguangzhou. Would it make sense to go directly to support for ray or what do you think?

This is feasible, but for users running training locally, Ray is not actually needed and would add system complexity.
For users of the open-source version, connecting to Ray is unnecessary in most cases. For the commercial version, however, it would be better to be based entirely on Ray, because all training will run on the Ray cluster.
WDYT? @blythed

> In your tasks-list you write "Ray and deepspeed to run local multi GPU training". Couldn't this also be remote?

That was a typo; it should be remote.

