
Comments (7)

pseudotensor commented on August 25, 2024

Some old notes from slack last week:

Right now, nothing in torch/huggingface can be used directly to do flash attention. One would need to swap out the attention layer, which is possible since that is what the gpt-neox repo does. I'll have to look more carefully at this approach to see how to do it, similar to how the other vicuna repo does it for llama.
An alternative is to use the gpt-neox repo directly with their training code, which is probably fine. I installed all their dependencies and nothing had issues.

source ~/.bashrc.mamba
# create and activate a fresh env for gpt-neox
mamba create -n gptneox
conda activate gptneox
mamba install python=3.8 -y
mamba install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia -y
cd gpt-neox/
pip install -r requirements/requirements.txt
# CUDA toolkit is needed to build the fused kernels
mamba install cudatoolkit-dev=11.7 cudatoolkit=11.7 -c conda-forge -c nvidia -y
unset CUDA_HOME
# build gpt-neox's fused kernels, then install the flash-attention requirements
python ./megatron/fused_kernels/setup.py install
pip install -r ./requirements/requirements-flashattention.txt
cd ..
# EleutherAI's DeepSpeed fork, used by gpt-neox
git clone https://github.com/EleutherAI/DeeperSpeed.git
cd DeeperSpeed
./install.sh

CUDA 11.7 is required.
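After the steps above, a quick sanity check can confirm the pieces landed in the env. This is just a sketch; it assumes the import names `torch`, `flash_attn`, and `deepspeed` and degrades gracefully if something is missing:

```python
# Sanity check after the install steps: confirm the key modules are
# importable and report torch's CUDA version (should be 11.7 here).
import importlib.util

for mod in ("torch", "flash_attn", "deepspeed"):
    present = importlib.util.find_spec(mod) is not None
    print(f"{mod}: {'found' if present else 'MISSING'}")

if importlib.util.find_spec("torch") is not None:
    import torch
    # expect "11.7" to match the cudatoolkit installed above
    print("torch CUDA version:", torch.version.cuda)
```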

from h2ogpt.

pseudotensor commented on August 25, 2024

WIP for neox using flash attention in huggingface transformers, but no work for the last 3 months, so it's probably dead: https://github.com/conceptofmind/flash-gpt


pseudotensor commented on August 25, 2024

Amazon thing: https://aws.amazon.com/blogs/machine-learning/new-performance-improvements-in-amazon-sagemaker-model-parallel-library/

To help our customers further minimize training costs and accelerate time-to-market, we are thrilled to introduce two new performance improvements in SageMaker model parallel — SMDDP Collectives and FlashAttention. SMDDP Collectives is the most performant collective library on AWS infrastructure for large model training offered by SageMaker distributed data parallel library. FlashAttention is introduced in Dao et al., which re-implements the attention mechanism in an IO-aware manner, reducing the memory bandwidth requirement and saving on attention speed and memory footprint. These two components collectively push our sharded data parallel technique to be 30.58% faster when training a 100B parameter GPT-NeoX model on 32 p4d.24xlarge instances. For customers who are already using sharded data parallel on supported models, no code changes are necessary to benefit from the performance boost offered by these latest features.

So maybe we should use SageMaker. I think I noticed this before somewhere else.
But I'm unsure how compatible it is with other weights, e.g. huggingface.

100B parameter GPT-NeoX model on 32 p4d.24xlarge instances


pseudotensor commented on August 25, 2024

You can use the same install above to make llama use flash attention via the wrappers/patches from the vicuna model:
https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train_mem.py#L5
So we can already do that for the llama case if we are interested.
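The FastChat files linked above work by monkey-patching the attention forward method at import time, before the model is built. The pattern itself, sketched here with a stand-in class instead of transformers' real `LlamaAttention` (so it runs without GPUs, transformers, or flash-attn installed), looks roughly like:

```python
# Sketch of the monkey-patch pattern: rebind a module class's forward
# to a flash-attention-backed one BEFORE any model instance is created.
# LlamaAttention here is a STAND-IN for transformers' real class.

class LlamaAttention:
    def forward(self, hidden_states):
        return ("standard_attention", hidden_states)

def flash_attn_forward(self, hidden_states):
    # the real patch dispatches to flash_attn's fused kernel here
    return ("flash_attention", hidden_states)

def replace_llama_attn_with_flash_attn():
    # rebinding the class attribute patches all future (and existing) instances
    LlamaAttention.forward = flash_attn_forward

replace_llama_attn_with_flash_attn()
attn = LlamaAttention()
print(attn.forward([1, 2, 3])[0])  # -> flash_attention
```

Because the patch mutates the class attribute, it must run before the model is instantiated, which is why train_mem.py applies it at the very top of the file.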


pseudotensor commented on August 25, 2024

EleutherAI/gpt-neox#725


arnocandel commented on August 25, 2024

^ specifically
https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/flash_attention.py


arnocandel commented on August 25, 2024

#128

