Comments (22)
Just as a supplement, the figure shows the detailed time cost of each task when 4 Raspberry Pis run Llama2-7B-Q40, so you can see how much time each of the aforementioned functions costs. And if you can make these computations parallel according to the 'Mature solution', the time will decrease nearly linearly as the number of devices increases. @b4rtaz
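As a rough illustration of why the scaling is only nearly (not exactly) linear, here is a toy Amdahl's-law estimate; the parallelizable fraction p below is a made-up placeholder, not a number measured from the figure:

#include <cstdio>

// Toy estimate: if a fraction p of the per-token work can be split across
// n devices and the rest (synchronization, root-only tasks) cannot, the
// achievable speedup is bounded by Amdahl's law: 1 / ((1 - p) + p / n).
int main() {
    const double p = 0.9; // assumed parallelizable fraction (placeholder)
    for (int n = 1; n <= 8; n *= 2)
        std::printf("n=%d devices -> %.2fx speedup\n", n,
                    1.0 / ((1.0 - p) + p / n));
    return 0;
}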
Nice measurements! It seems multiheadAtt is super slow.
@zhengpeirong please check version 0.3.1. Now all tasks are executed in parallel, so it should be a bit better.
@b4rtaz The 'qkv' optimization has been reverted. Do you plan to deal with this issue? Not only does 'MulHead' cost time; 'Finalize' also takes a big portion of the time.
@zhengpeirong yes, I know. The qkv part seems to be quite well optimized if you look at the rest of the layers. Still, qkv may be improved in the way you suggested in the first post. I didn't have time to read it yet.

The problem with the finalize layer is that its output is large (vocabSize), so I don't think it's a good idea to synchronize it. But maybe it could be optimized so that each worker runs the sampler on a slice of its own output, and then the root node merges the results somehow. Different samplers would require different merging logic, but it looks doable (for example, sample_argmax looks super easy).
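For the argmax case, a minimal sketch of what that merge could look like (hypothetical types and names, nothing from the actual code): each worker sends only the best token index and logit from its own slice of the vocabulary, and the root picks the best of those candidates instead of receiving the full vocabSize output:

// Hypothetical per-slice candidate: the best token within one worker's
// slice of the logits (token is a global vocabulary index).
struct SliceArgmax {
    int token;
    float logit;
};

// Root-side merge for sample_argmax: the global argmax is just the best
// of the per-slice argmaxes, so only nSlices (token, logit) pairs travel
// over the network instead of vocabSize logits.
int mergeArgmax(const SliceArgmax* candidates, int nSlices) {
    int best = candidates[0].token;
    float bestLogit = candidates[0].logit;
    for (int i = 1; i < nSlices; i++) {
        if (candidates[i].logit > bestLogit) {
            bestLogit = candidates[i].logit;
            best = candidates[i].token;
        }
    }
    return best;
}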
Yes, I want to keep working on this project. More hands are welcome. :-)
@b4rtaz Thanks for your persistence and effort.

- The qkv can be optimized; all you need to read is Section 3, "Model Parallel Transformers", of the paper "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism".
- The finalize can be optimized following your design.
- The number of transfers for the FFN layers can be reduced from 2 to 1 per layer by using the Megatron-LM method (see the sketch after this list).
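For reference, here is the Megatron-LM idea in miniature: the first FFN matmul is split column-wise and the second row-wise, so each node runs both matmuls and the SwiGLU activation on its own shard, and only one synchronization (summing the partial outputs) is needed per FFN block. This is a hedged sketch with invented names and shapes, not the repo's API:

#include <cmath>

// Hypothetical Megatron-style FFN shard (all names invented here).
// W1/W3 shards are column-parallel [dim x hiddenSlice]; the W2 shard is
// row-parallel [hiddenSlice x dim]. Each node produces a partial output;
// summing partialOut element-wise across nodes (one all-reduce) yields
// the exact FFN result, i.e. a single sync per FFN block.
void ffnShard(const float* x, const float* w1, const float* w3,
              const float* w2, float* h, float* partialOut,
              int dim, int hiddenSlice) {
    // h = silu(x @ W1_shard) * (x @ W3_shard): no sync needed, because
    // each node owns whole columns of W1 and W3.
    for (int j = 0; j < hiddenSlice; j++) {
        float a = 0.0f, b = 0.0f;
        for (int i = 0; i < dim; i++) {
            a += x[i] * w1[i * hiddenSlice + j];
            b += x[i] * w3[i * hiddenSlice + j];
        }
        h[j] = (a / (1.0f + std::exp(-a))) * b; // SwiGLU activation
    }
    // partialOut = h @ W2_shard; the element-wise sum over nodes is the
    // single communication step.
    for (int i = 0; i < dim; i++) {
        float sum = 0.0f;
        for (int j = 0; j < hiddenSlice; j++)
            sum += h[j] * w2[j * dim + i];
        partialOut[i] = sum;
    }
}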
If you combine all those mechanisms, the non-parallel functions will be optimized! Here is the draft workflow:
TransformerArch buildLlama2Arch(TransformerSpec* spec) {
    TransformerArch a;

    // Inference (root node)
    a.I(sendPoke, TASK_TYPE_TRANSFER);
    for (int i = 0; i < spec->nLayers; i++) {
        a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);     // Combine the existing llamaRmsAtt and llamaRmsAttNorm
        a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); // Quantization
        a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);      // Sending
        a.I(llamaQkv, TASK_TYPE_INFERENCE);            // Compute Q, K, V
        a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE);   // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.I(llamaAttOutput, TASK_TYPE_INFERENCE);      // Each node computes its slice of the W_O matrix
        a.I(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaSyncAtt, TASK_TYPE_TRANSFER);         // First communication (time-consuming)
        a.I(llamaDequantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaMergeAtt, TASK_TYPE_INFERENCE);       // Merge all partial attention outputs
        a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);
        a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.I(llamaFfn, TASK_TYPE_INFERENCE);            // Compute SwiGLU activation
        a.I(llamaFfn2, TASK_TYPE_INFERENCE);           // Compute the second FFN matmul
        a.I(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaSyncFfn2, TASK_TYPE_TRANSFER);        // Second communication (time-consuming)
        a.I(llamaDequantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaMergeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.I(llamaRmsFinal, TASK_TYPE_INFERENCE);
    a.I(llamaRmsFinalNorm, TASK_TYPE_INFERENCE);
    a.I(llamaLogits, TASK_TYPE_INFERENCE);
    a.I(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaSyncLogits, TASK_TYPE_TRANSFER);
    a.I(llamaDequantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaMergeLogits, TASK_TYPE_INFERENCE);

    // Worker
    for (int i = 0; i < spec->nLayers; i++) {
        a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
        a.W(llamaQkv, TASK_TYPE_INFERENCE);            // Compute Q, K, V
        a.W(llamaMultiheadAtt, TASK_TYPE_INFERENCE);   // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.W(llamaAttOutput, TASK_TYPE_INFERENCE);      // Worker computes its slice of the W_O matrix
        a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.W(llamaSyncAtt, TASK_TYPE_TRANSFER);
        a.W(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.W(llamaFfn, TASK_TYPE_INFERENCE);
        a.W(llamaFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaSyncFfn2, TASK_TYPE_TRANSFER);
        a.W(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.W(llamaLogits, TASK_TYPE_INFERENCE);
    a.W(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.W(llamaSyncLogits, TASK_TYPE_TRANSFER);
    return a;
}
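In this draft, the two merge steps (llamaMergeAtt and llamaMergeFfn2) reduce to element-wise sums of the nodes' partial outputs, because W_O and the second FFN matrix are split row-wise. A minimal sketch of such a merge (hypothetical signature, not the actual repo API):

// Hypothetical root-side merge: with row-parallel weight shards, each of
// the nSlices nodes produces a partial [dim] vector, and the exact result
// is their element-wise sum (effectively an all-reduce).
void mergePartials(float* out, const float* const* partials,
                   int nSlices, int dim) {
    for (int i = 0; i < dim; i++) {
        float sum = 0.0f;
        for (int s = 0; s < nSlices; s++)
            sum += partials[s][i];
        out[i] = sum;
    }
}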
I hope this repo can catch up with the state-of-the-art algorithms as soon as possible!
I have roughly computed the optimized result: the generation time would be only 72% of the original, i.e. about a 1.39x speedup over the current version (1 / 0.72 ≈ 1.39). Specifically, the main transfer happens only twice per layer, and the root node's workload is divided among the 4 workers.
You did a great job! Have you considered open-sourcing this method in your branch?