Background
In SuperDuperDB, the scenarios for LLM training include:
- Large volumes of document data for continued pre-training.
- Purpose-specific data for Supervised Fine-Tuning (SFT).
- Other scenarios to be determined.
The advantages of training on SuperDuperDB for the above scenarios include:
- Data Security: Sensitive data remains within the database.
- Data Utilization: Training data can be managed through the database, replacing traditional file formats or other cloud storage (such as HuggingFace datasets). Since LLM training data is mostly text, MongoDB offers significant advantages here.
- Data Loading: Data need not be exported but is loaded through a direct database connection using the specialized Dataset mechanism of SuperDuperDB.
- Trainer reuse: Training configurations are recorded in the Metadata Store and can be reloaded with `db.load("trainer", 'llm_trainer_xxx')` for reproducibility or model iteration.
For LoRA-based fine-tuning, the most frequent task is model training on single or multiple GPUs. Different training needs call for different technologies:
- Scalable training: Ray
- Distributed training with insufficient memory: DeepSpeed, FSDP
- Remote training task submission: Ray
Use cases
Local Training with Simple Configuration
```python
from superduperdb import superduper
from superduperdb.backends.mongodb import Collection
from superduperdb.ext.llm.train import LLMTrainer

db = superduper(mongodb_uri)

prompt_template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
    "\n\n### Instruction:\n{X}\n\n### Response:\n{Y}"
)

trainer = LLMTrainer(
    model="facebook/opt-125m", bits=4, qlora=True, prompt_template=prompt_template
)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())
```
Local Training with More Granular Configuration
```python
from superduperdb import superduper
from superduperdb.backends.mongodb import Collection
from superduperdb.ext.llm.train import (
    LLMTrainer, LoraArguments, TrainingArguments, DataArguments, ModelArguments,
)

db = superduper(mongodb_uri)

model_args = ModelArguments(model="facebook/opt-125m", ...)
lora_args = LoraArguments(bits=4, qlora=True, ...)
training_args = TrainingArguments(output='./outputs', learning_rate=1e-5, ...)
data_args = DataArguments(cache_size=10000, max_length=1024, prompt_template=prompt_template)

trainer = LLMTrainer(model_args, lora_args, training_args, data_args)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())
```
Deepspeed Configuration (FSDP configuration is similar, but for now we only need to support one of the two; they provide similar functionality and DeepSpeed is reportedly better)
```python
from superduperdb import superduper
from superduperdb.backends.mongodb import Collection
from superduperdb.ext.llm.train import LLMTrainer, DeepspeedConfig

db = superduper(mongodb_uri)

deepspeed_config = DeepspeedConfig(...)
trainer = LLMTrainer(model="facebook/opt-125m", deepspeed_config=deepspeed_config)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())
```
Ray Configuration
```python
from superduperdb import superduper
from superduperdb.backends.mongodb import Collection
from superduperdb.ext.llm.train import LLMTrainer, RayConfig

db = superduper(mongodb_uri)

ray_config = RayConfig(...)
trainer = LLMTrainer(model="facebook/opt-125m", ray_config=ray_config)
trainer.fit(X="instruction", Y="output", db=db, select=Collection("docs").find())
```
Custom Training
Not supported yet.
Implementation Methods
LoRA
QLoRA is implemented using the PEFT package; no special notes.
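As a rough illustration of what the PEFT-based QLoRA setup involves, the sketch below builds the 4-bit quantization and LoRA configs. The helper name `build_qlora_config` and its defaults are hypothetical; the real `LLMTrainer` would derive these values from `LoraArguments`:

```python
def build_qlora_config(bits=4, lora_r=8, lora_alpha=16, lora_dropout=0.05):
    """Build 4-bit quantization + LoRA configs (hypothetical helper; the
    real trainer would derive these from LoraArguments)."""
    import torch
    from transformers import BitsAndBytesConfig
    from peft import LoraConfig

    # NF4 quantization with double quantization is the usual QLoRA recipe.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=(bits == 4),
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # Low-rank adapters on the attention projections of a causal LM.
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        task_type="CAUSAL_LM",
    )
    return bnb_config, lora_config
```

The two configs would then be passed to `AutoModelForCausalLM.from_pretrained(..., quantization_config=bnb_config)` and `peft.get_peft_model(model, lora_config)`.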
Data Saving
- Save checkpoints locally.
- Save checkpoints to the ArtifactStore (lower priority; to be added after the other items are complete).
- Save checkpoints to S3 (used for remote training with Ray; Ray also supports this natively, so it should be reusable).
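A minimal sketch of the local-first saving flow above, with an optional hook for pushing the checkpoint to the ArtifactStore or S3. `persist_checkpoint` and the `upload` callback are hypothetical names, not the actual SuperDuperDB API:

```python
import os
import shutil

def persist_checkpoint(checkpoint_dir, local_root, upload=None):
    """Copy a checkpoint directory under `local_root`; optionally hand the
    copy to `upload` (e.g. an ArtifactStore or S3 uploader).
    All names here are hypothetical sketches."""
    target = os.path.join(local_root, os.path.basename(checkpoint_dir))
    # Local save is always performed; remote storage is an optional layer on top.
    shutil.copytree(checkpoint_dir, target, dirs_exist_ok=True)
    if upload is not None:
        upload(target)  # e.g. push to ArtifactStore or S3
    return target
```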
Weight Merging
Currently, vLLM does not support loading LoRA weights. To deploy the model, we need to merge the base model and the LoRA weights, which can be done with a merge script. This will be discarded once vLLM supports LoRA adapters.
Current LoRA model inference plan:
- Use native transformers support, which might perform slightly worse.
- Merge the weights, then load the model with vLLM.
Switch to the vLLM solution once vLLM supports Multi-LoRA.
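The merge script could look roughly like the following, using PEFT's `merge_and_unload`; the function name and argument layout are assumptions for illustration:

```python
def merge_lora_weights(base_model_id, adapter_dir, output_dir):
    """Merge trained LoRA adapter weights into the base model so the
    result can be served by vLLM (sketch; names are hypothetical)."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(base_model_id)
    model = PeftModel.from_pretrained(base, adapter_dir)
    merged = model.merge_and_unload()  # fold LoRA deltas into the base weights
    merged.save_pretrained(output_dir)
    # Ship the tokenizer alongside the merged weights for deployment.
    AutoTokenizer.from_pretrained(base_model_id).save_pretrained(output_dir)
```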
Datasets
Currently, superduperdb has two relevant datasets:
- `QueryDataset`: for loading small datasets, read in full.
- `CachedQueryDataset`: for loading large datasets, read in batches.
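To make the batched-read idea concrete, here is a toy version of the cached approach. `CachedQueryDatasetSketch` and its `fetch_range` callable are illustrative only, not the real implementation:

```python
class CachedQueryDatasetSketch:
    """Toy cache-backed dataset: items are fetched from the database
    in chunks of `cache_size` instead of being read in full."""

    def __init__(self, fetch_range, length, cache_size=1000):
        self.fetch_range = fetch_range  # callable: (start, end) -> list of docs
        self.length = length
        self.cache_size = cache_size
        self._cache_start = None
        self._cache = []

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        start = (index // self.cache_size) * self.cache_size
        if start != self._cache_start:  # cache miss: refill from the database
            end = min(start + self.cache_size, self.length)
            self._cache = self.fetch_range(start, end)
            self._cache_start = start
        return self._cache[index - start]
```

Sequential access then touches the database once per `cache_size` items rather than once per item.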
Datasets Handling in Multiprocessing
Frameworks such as DeepSpeed launch multiple processes for distributed training, so the dataset may need special handling, such as using `local_rank` to decide which records each process fetches. This will need to be adapted in practice.
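One simple option is strided sharding by rank, sketched below; PyTorch's `DistributedSampler` implements a production version of the same idea:

```python
def shard_for_rank(dataset, local_rank, world_size):
    """Give each distributed worker a disjoint slice of the data
    (simple strided sharding; a sketch, not the production logic)."""
    # Worker r takes items r, r + world_size, r + 2 * world_size, ...
    return dataset[local_rank::world_size]
```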
Datasets Handling in Ray Remote Training
Ray workers may not be able to connect to the database, so Ray compute clusters are assumed to have no database access. The query results should therefore be converted into a Ray Dataset before being shipped to the cluster.
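A sketch of the handoff: materialize the query results on the driver (which can reach the database) as plain dicts, then build a Ray Dataset from them. The Ray call is shown as a comment, and the helper name is hypothetical:

```python
def export_query_for_ray(cursor):
    """Materialize database query results as plain dicts so they can be
    shipped to a Ray cluster that has no database access (sketch)."""
    docs = [dict(doc) for doc in cursor]
    # On the driver, the docs can then be turned into a Ray Dataset:
    #   import ray
    #   ds = ray.data.from_items(docs)
    return docs
```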
Launcher
Distributed training generally requires a specific launcher such as deepspeed or accelerate. The scenarios are:
- Single-GPU training or a Jupyter notebook: plain Python can be used as the launcher.
- Multi-GPU training: requires deepspeed or accelerate as the launcher.
- Ray (or Ray + DeepSpeed) distributed training: plain Python can be used directly as the launcher.
The special case is therefore non-Ray distributed training. There are two plans:
1. Run distributed training without configuring Ray, using DeepSpeed as the launcher; everything else is unrestricted.
2. Default to Ray for distributed training, initializing a local Ray cluster or connecting to a remote one with `ray.init()`.
The current suggestion is to proceed with plan 1, since it only requires changing the launcher at startup. Once Ray distributed training runs smoothly and the major bugs are resolved, we can switch directly to plan 2. The switching cost is low: Ray support has to be built anyway, and the change should require only minimal code later.
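Plan 1 amounts to choosing the launcher at startup. A hypothetical helper sketching that decision (the function and its rules are illustrative, not the actual implementation):

```python
def launch_command(script, num_gpus, use_ray=False):
    """Pick a launch command for the training script (hypothetical helper).

    Plain Python suffices for single-GPU runs and for Ray-managed training;
    local multi-GPU runs go through the DeepSpeed launcher (plan 1)."""
    if use_ray or num_gpus <= 1:
        return ["python", script]
    return ["deepspeed", f"--num_gpus={num_gpus}", script]
```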
Experimental Records
The Trainer, as a Component, can be recorded in the `metadata_store` for experiment reproducibility and model iteration. Either of the following operations records it:

```python
trainer.fit(db=db, ...)
# or
db.add(trainer)
```

For experiment reproducibility / model iteration:

```python
trainer = db.load("trainer", 'xxxx')
trainer.fit(...)
```

All configuration information will be recorded in the metadata.
Training Framework/Code
A similar reusable framework is trl (https://github.com/huggingface/trl), a basic training framework that encapsulates many of the details of the LLM scenario and provides RLHF training. We are not using it for now, so we will implement the related code ourselves.
Demo Code
- DeepSpeed: https://github.com/SuperDuperDB/llama2-experiments/blob/main/new_experiments/train_new.py
- DeepSpeed + Ray: https://github.com/SuperDuperDB/llama2-experiments/blob/main/new_experiments/gptj_deepspeed_fine_tuning.ipynb
Discussion
Whether to make this a separate repo or implement it directly in the superduperdb repo. For example: `pip install superduperdb-llm-trainer`.
Thanks @jieguangzhou. Would it make sense to go directly to support for ray, or what do you think?
In your tasks-list you write "Ray and deepspeed to run local multi GPU training". Couldn't this also be remote?
> Thanks @jieguangzhou. Would it make sense to go directly to support for ray, or what do you think?
This is feasible, but for users running training locally, Ray is not actually needed, and requiring it may increase system complexity.
For users of the open-source version, connecting to Ray may not be necessary in most cases. For the commercial version, however, it would be better to be entirely Ray-based, because all training will run on the Ray cluster.
WDYT? @blythed
> In your tasks-list you write "Ray and deepspeed to run local multi GPU training". Couldn't this also be remote?
That was a typo; it should say "remote".