
This project is a fork of nvidia/transformerengine.


A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

License: Apache License 2.0

Languages: Shell 0.24%, C++ 12.27%, Python 52.83%, C 2.08%, CUDA 32.26%, CMake 0.32%


Transformer Engine

Quickstart | Installation | User Guide | Examples | FP8 Convergence | Integrations | Release notes

Latest News

[Image: NVIDIA H200]

What is Transformer Engine?

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference. TE provides a collection of highly optimized building blocks for popular Transformer architectures and an automatic-mixed-precision-like API that can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for architectures such as BERT, GPT, and T5 become very memory- and compute-intensive. Most deep learning frameworks train with FP32 by default, but FP32 is not essential to achieve full accuracy for many deep learning models. Mixed-precision training, which combines single precision (FP32) with a lower-precision format (e.g. FP16), yields significant speedups with minimal loss of accuracy compared to pure FP32 training. The Hopper GPU architecture introduced FP8 precision, which offers improved performance over FP16 with no degradation in accuracy. Although all major deep learning frameworks support FP16, FP8 support is not yet available natively in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support. Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly simplifying mixed precision training for users.

Highlights

  • Easy-to-use modules for building Transformer layers with FP8 support
  • Optimizations (e.g. fused kernels) for Transformer models
  • Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
  • Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later

Examples

PyTorch
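
A minimal sketch of FP8 training with the PyTorch API, following the pattern of the official quickstart; the layer dimensions here are illustrative:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Illustrative dimensions.
    in_features = 768
    out_features = 3072
    hidden_size = 2048

    # Initialize an FP8-capable TE layer and an input tensor on the GPU.
    model = te.Linear(in_features, out_features, bias=True)
    inp = torch.randn(hidden_size, in_features, device="cuda")

    # Create an FP8 recipe; all arguments are optional.
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

    # Enable FP8 autocasting for the forward pass. TE maintains the
    # scaling factors and other FP8 state internally.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)

    loss = out.sum()
    loss.backward()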

JAX

Flax
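
A corresponding JAX/Flax sketch, assuming the te_flax.DenseGeneral module and the fp8_autocast context manager from transformer_engine.jax; the batch, sequence, and hidden sizes are illustrative:

    import flax
    import jax
    import jax.numpy as jnp
    import transformer_engine.jax as te
    import transformer_engine.jax.flax as te_flax
    from transformer_engine.common import recipe

    BATCH, SEQLEN, HIDDEN = 32, 128, 1024

    # Initialize RNGs and a random input batch.
    rng = jax.random.PRNGKey(0)
    init_rng, data_rng = jax.random.split(rng)
    inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

    # Create an FP8 recipe; all arguments are optional.
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

    # Enable FP8 autocasting.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        model = te_flax.DenseGeneral(features=HIDDEN)

        def loss_fn(params, other_vars, inp):
            out = model.apply({'params': params, **other_vars}, inp)
            return jnp.mean(out)

        # Initialize parameters and split out the non-parameter variables.
        variables = model.init(init_rng, inp)
        other_variables, params = flax.core.pop(variables, 'params')

        # Compute the loss and gradients.
        fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)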

Installation

Prerequisites

  • Linux x86_64
  • CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada
  • NVIDIA Driver supporting CUDA 11.8 or later
  • cuDNN 8.1 or later
  • For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.

Docker

The quickest way to get started with Transformer Engine is to use the Docker images on the NVIDIA GPU Cloud (NGC) Catalog. For example, to use the NGC PyTorch container interactively:
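
    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3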

where 23.10 is the container version; in this case, the October 2023 release.

pip

To install the latest stable version of Transformer Engine:
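
    pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable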

This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
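
For instance, to build explicitly for both JAX and PyTorch:

    NVTE_FRAMEWORK=jax,pytorch pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable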

From source

See the installation guide.

Compiling with FlashAttention-2

Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.

It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see bug), which may lead to out-of-memory errors during the installation of Transformer Engine. Try setting MAX_JOBS=1 in the environment to circumvent the issue, as shown below. If the errors persist, install a supported version of FlashAttention-1 (v1.0.6 to v1.0.9).
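
For example, with the stable-branch install command from above:

    MAX_JOBS=1 pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable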

Note that NGC PyTorch 23.08+ containers include FlashAttention-2.

FP8 Convergence

FP8 has been tested extensively across different model architectures and configurations, and we found no significant difference between FP8 and BF16 training loss curves. FP8 has also been validated for accuracy on downstream LLM tasks (e.g. LAMBADA and WikiText). Below are examples of models tested for convergence across different frameworks.

Model        Framework        Source
T5-770M      JAX/T5x          https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance
MPT-1.3B     Mosaic Composer  https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1
GPT-5B       JAX/Paxml        https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results
GPT-5B       NeMo Framework   Available on request
Llama2-7B    Alibaba Pai      https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ
T5-11B       JAX/T5x          Available on request
MPT-13B      Mosaic Composer  https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8
GPT-22B      NeMo Framework   Available on request
Llama2-70B   Alibaba Pai      https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ
GPT-175B     JAX/Paxml        https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results

Integrations

Transformer Engine has been integrated with popular LLM frameworks such as DeepSpeed, Hugging Face Accelerate, Lightning, MosaicML Composer, NVIDIA JAX Toolbox, NVIDIA Megatron-LM, and NVIDIA NeMo Framework.

Contributing

We welcome contributions to Transformer Engine! To contribute and open pull requests, follow the guidelines outlined in the CONTRIBUTING.rst guide.

Papers

Videos

Contributors

asfiyab-nvidia, cyanguwa, denera, erhoo82, galagam, hugo-syn, jeng1220, kaixih, ksivaman, marks101, mingxu1067, minitu, nouiz, nzmora-nvidia, oleg-goncharov, ptrendx, quentin-anthony, rachitgarg91, sanandaraj5597, sbhavani, schetlur-nv, sudhakarsingh27, timmoon10, tom-zheng, trevor-m, vasunvidia, victarry, wong4j, xrennvidia, zlsh80826
