
Comments (22)

zhengpeirong commented on May 31, 2024

[Figure: per-task time breakdown when 4 Raspberry Pis run Llama2-7B-Q40]
Just as a supplement, the figure shows the detailed time cost of each task when 4 Raspberry Pis run Llama2-7B-Q40. As you can see, it shows how much time the aforementioned function costs. If these computations are made parallel according to the 'Mature solution', the time will decrease nearly linearly as the number of devices increases. @b4rtaz


b4rtaz commented on May 31, 2024

Nice measurements! It seems multiheadAtt is super slow.

@zhengpeirong please check the 0.3.1 version. Now all tasks are executed in parallel, so it should be a bit better.


zhengpeirong commented on May 31, 2024

@b4rtaz The 'qkv' has been reverted. Do you plan to deal with this issue? Not only does 'MulHead' cost time; 'Finalize' also takes a large portion of the time.


b4rtaz commented on May 31, 2024

@zhengpeirong yes, I know. The qkv seems to be quite well optimized if you look at the rest of the layers. Still, the qkv may be improved in the way you suggested in the first post. I didn't have time to read it yet.

The problem with the finalize layer is that its output is large (vocabSize), and I think it's not a good idea to synchronize it. But maybe it could be optimized so that each worker runs the sampler on its own slice of the output, and the root node then merges the results somehow. Different samplers would require different merging logic, but it looks doable (for example, sample_argmax looks super easy).
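A minimal sketch of the argmax case of that idea (the struct and function names below are illustrative assumptions, not existing distributed-llama code): each worker reduces its logits slice to a single best candidate, so the root only compares one candidate per worker instead of scanning vocabSize logits.

#include <cstdio>
#include <vector>

// Hypothetical result a worker would send back: the best logit found in its
// slice of the output and where that logit sits.
struct SliceArgmax {
    float value;      // best logit inside the slice
    int localIndex;   // position of that logit within the slice
    int sliceOffset;  // where the slice starts in the full vocabulary
};

// Root-side merge for sample_argmax: compare one candidate per worker and
// map the winner back to a global token id.
int mergeArgmax(const std::vector<SliceArgmax>& candidates) {
    int bestToken = -1;
    float bestValue = -1e30f;
    for (const SliceArgmax& c : candidates) {
        if (c.value > bestValue) {
            bestValue = c.value;
            bestToken = c.sliceOffset + c.localIndex;
        }
    }
    return bestToken;
}

int main() {
    // Example: 4 workers, each owning a quarter of a 32000-token vocabulary.
    std::vector<SliceArgmax> candidates = {
        {1.2f, 10, 0}, {3.4f, 5, 8000}, {2.1f, 700, 16000}, {0.9f, 42, 24000}
    };
    printf("predicted token = %d\n", mergeArgmax(candidates)); // prints 8005
    return 0;
}

With this scheme only one (value, index) pair per worker has to be transferred instead of the full logits. Samplers like top-p would need each worker to send a short list of candidates rather than a single one, which is the different merging logic mentioned above.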

Yes, I want to keep working on this project. More hands are welcome. :-)


zhengpeirong commented on May 31, 2024

@b4rtaz Thanks for your persistence and effort.

  1. The qkv can be optimized; all you need to read is Section 3, "Model Parallel Transformers", of the paper "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism".
  2. The finalize layer can be optimized following your design.
  3. The number of transfers for the FFN layer can be reduced from 2 to 1 by utilizing the method in the Megatron-LM paper (a rough single-process sketch of this split follows the draft workflow below).

If you combine all those mechanisms, the non-parallel functions will be optimized! Here is the draft workflow:

TransformerArch buildLlama2Arch(TransformerSpec* spec) {
    TransformerArch a;

    // Inference

    a.I(sendPoke, TASK_TYPE_TRANSFER);
    for (int i = 0; i < spec->nLayers; i++) {
        a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE); // Combine the existing llamaRmsAtt and llamaRmsAttNorm
        a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); // Quantization
        a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); // Sending
        a.I(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
        a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.I(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
        a.I(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaSyncAtt, TASK_TYPE_TRANSFER); // First communication time-consuming
        a.I(llamaDequantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaMergeAtt, TASK_TYPE_INFERENCE); // Merge all attention matrices
        a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);
        a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.I(llamaFfn, TASK_TYPE_INFERENCE); // Compute SwiGLU activation
        a.I(llamaFfn2, TASK_TYPE_INFERENCE); // Compute the second FFN
        a.I(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaSyncFfn2, TASK_TYPE_TRANSFER); // Second communication time-consuming
        a.I(llamaDequantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaMergeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.I(llamaRmsFinal, TASK_TYPE_INFERENCE);
    a.I(llamaRmsFinalNorm, TASK_TYPE_INFERENCE);
    a.I(llamaLogits, TASK_TYPE_INFERENCE);
    a.I(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaSyncLogits, TASK_TYPE_TRANSFER);
    a.I(llamaDequantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaMergeLogits, TASK_TYPE_INFERENCE);

    // Worker

    for (int i = 0; i < spec->nLayers; i++) {
        a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
        a.W(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
        a.W(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.W(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
        a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.W(llamaSyncAtt, TASK_TYPE_TRANSFER);
        a.W(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.W(llamaFfn, TASK_TYPE_INFERENCE);
        a.W(llamaFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaSyncFfn2, TASK_TYPE_TRANSFER);
        a.W(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.W(llamaLogits, TASK_TYPE_INFERENCE);
    a.W(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.W(llamaSyncLogits, TASK_TYPE_TRANSFER);

    return a;
}
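For reference, here is a minimal single-process sketch of the Megatron-LM split behind point 3 (all dimensions and names are illustrative assumptions, and a plain ReLU stands in for SwiGLU): the first FFN matrix is split by columns and the second by rows, so each worker runs both matmuls on its own slice with no sync in between, and the partial outputs only need one summation (all-reduce) per FFN block.

#include <cstdio>
#include <vector>

// Toy dimensions and weights; real code would shard the weights across nodes
// and replace the final loop with an all-reduce over the network.
constexpr int DIM = 8;        // model dimension
constexpr int HIDDEN = 16;    // full FFN hidden dimension, before slicing
constexpr int N_WORKERS = 4;
constexpr int HIDDEN_SLICE = HIDDEN / N_WORKERS;

// y = W * x for a (rows x cols) row-major matrix.
static void matmul(const float* w, const float* x, float* y, int rows, int cols) {
    for (int r = 0; r < rows; r++) {
        float acc = 0.0f;
        for (int c = 0; c < cols; c++) acc += w[r * cols + c] * x[c];
        y[r] = acc;
    }
}

// One worker's share of the FFN: column-parallel W1 (DIM -> HIDDEN_SLICE),
// then row-parallel W2 (HIDDEN_SLICE -> DIM). The result is a *partial*
// DIM-sized output; no sync is needed between the two matmuls.
static void ffnWorker(const float* w1, const float* w2, const float* x, float* partialOut) {
    float hidden[HIDDEN_SLICE];
    matmul(w1, x, hidden, HIDDEN_SLICE, DIM);
    for (int i = 0; i < HIDDEN_SLICE; i++)           // activation (ReLU here for brevity;
        hidden[i] = hidden[i] > 0 ? hidden[i] : 0;   // Llama would use SwiGLU with a W3 slice)
    matmul(w2, hidden, partialOut, DIM, HIDDEN_SLICE);
}

int main() {
    std::vector<float> x(DIM, 1.0f);
    std::vector<float> w1(N_WORKERS * HIDDEN_SLICE * DIM, 0.01f);
    std::vector<float> w2(N_WORKERS * DIM * HIDDEN_SLICE, 0.01f);

    // Each worker produces a partial DIM-sized output...
    float out[DIM] = {0};
    for (int w = 0; w < N_WORKERS; w++) {
        float partial[DIM];
        ffnWorker(&w1[w * HIDDEN_SLICE * DIM], &w2[w * DIM * HIDDEN_SLICE], x.data(), partial);
        // ...and the only communication is this single sum (all-reduce) over the partials.
        for (int i = 0; i < DIM; i++) out[i] += partial[i];
    }
    printf("out[0] = %f\n", out[0]);
    return 0;
}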

I hope this repo can catch up with the state-of-the-art algorithms as soon as possible!


zhengpeirong commented on May 31, 2024

The optimized generation time would be only 72% of the original, i.e. a 1.39x speedup over this version (1 / 0.72 ≈ 1.39).
I have roughly computed the optimized result. Specifically, the main time-consuming transfer only happens twice, and the root node's workload is divided among the 4 workers.
[Figure: rough estimate of the optimized per-task time costs]


galenyu commented on May 31, 2024

You did a great job! Have you considered open-sourcing this method in your branch?

