pytorch-labs / ao
torchao: PyTorch Architecture Optimization (AO). A repository to host AO techniques and performant kernels that work with PyTorch.
License: BSD 3-Clause "New" or "Revised" License
We plan to add QAT for LLMs to torchao (as mentioned in the original RFC, #47).
For this to run efficiently on the GPU, we'd need kernel support for W4A8 quantization (int4 weights, int8 activations).
Other places where this has been raised before:
NVIDIA/cutlass#1316
NVIDIA/cutlass#1370
cc @andrewor14
Last year, we released pytorch-labs/torchao to provide acceleration of Generative AI models using native PyTorch techniques. Torchao added support for running quantization on GPUs, including int8 dynamic quantization (W8A8) and weight-only quantization (int8 and int4) that were composable with torch.compile. Combined, the APIs launched in torchao were able to power SOTA generative AI models across multiple modalities: Segment Anything, Stable Diffusion, and LLaMa.
The results were showcased in these blog posts:
https://pytorch.org/blog/accelerating-generative-ai/
https://pytorch.org/blog/accelerating-generative-ai-2/
https://pytorch.org/blog/accelerating-generative-ai-3/
Our goal with torchao is to accelerate generative AI using native PyTorch features while ensuring composability with torch.compile.
In 2024, we plan to adopt the following strategy for development of torchao
Let’s dive deeper into some of the coverage areas mentioned above.
Dtypes like NF4, MX4, and groupwise-quantized int4 are used to implement various optimization techniques in models. Last year, we posted a plan for how we wish to support these dtypes in PyTorch. In torchao, we will host tensor-subclass-based implementations of dtypes. Existing examples include uint4 and NF4, which users can use for their own quantization techniques or override to support other dtypes that might be useful.
Moreover, users don't need to write Triton or CUDA kernels for their custom dtypes: the implementation can be in Python, and torch.compile will take care of generating performant kernels under the hood.
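To make "groupwise quantized int4" concrete, here is a framework-free sketch of the arithmetic behind such a dtype. It assumes an asymmetric scheme with one (scale, min) pair per group and a group size of 4 chosen for readability; real implementations typically use larger groups, pack two 4-bit codes per byte, and store the codes inside a tensor subclass. The function names are hypothetical, not torchao APIs.

```python
def quantize_groupwise_uint4(weights, group_size=4):
    """Map a flat list of floats to 4-bit codes (0..15), one (scale, min) pair per group."""
    if len(weights) % group_size:
        raise ValueError("len(weights) must be a multiple of group_size")
    codes, scales, mins = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / 15 or 1.0  # 16 levels = 15 steps; guard constant groups
        codes.extend(min(15, max(0, round((w - lo) / scale))) for w in group)
        scales.append(scale)
        mins.append(lo)
    return codes, scales, mins


def dequantize_groupwise_uint4(codes, scales, mins, group_size=4):
    """Invert the mapping: w ~= code * scale + min, using each code's group parameters."""
    return [c * scales[i // group_size] + mins[i // group_size] for i, c in enumerate(codes)]


weights = [0.12, -0.5, 0.33, 0.9, 2.0, 1.5, -1.0, 0.0]
codes, scales, mins = quantize_groupwise_uint4(weights)
restored = dequantize_groupwise_uint4(codes, scales, mins)
```

Because each group gets its own range, an outlier in one group (here 2.0) does not coarsen the resolution of the other groups; with round-to-nearest, every restored value lies within half a step of the original.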
Quantization can be applied to weights only, or to both weights and activations. LLM quantization techniques for batch size 1, which is memory-bandwidth bound, typically use weight-only quantization. For larger batch sizes, longer context lengths, or throughput-bound models in general, quantizing the activations is also beneficial. Quantization does, however, affect model accuracy, and researchers have published techniques to mitigate this impact; these currently live externally, one repository per technique.
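Why weight-only quantization helps at batch size 1 can be seen with back-of-the-envelope arithmetic: when decoding one token at a time, every weight is read from memory once per token, so weight bytes divided by memory bandwidth gives a lower bound on time per token. The sketch below assumes a hypothetical 7B-parameter model and counts one 16-bit scale plus one 16-bit zero point per group for groupwise formats; it ignores the KV cache, activations, and the small per-channel scales of int8.

```python
def weight_bytes(num_params, bits_per_weight, group_size=None, param_bits=16):
    """Bytes to store the weights; groupwise formats add a scale and a zero point per group."""
    total_bits = num_params * bits_per_weight
    if group_size is not None:
        total_bits += (num_params // group_size) * 2 * param_bits
    return total_bits / 8


n = 7_000_000_000  # hypothetical 7B-parameter model
for name, bits, group in [("bf16", 16, None), ("int8", 8, None), ("int4, group 128", 4, 128)]:
    print(f"{name:>16}: {weight_bytes(n, bits, group) / 1e9:.2f} GB")
```

Under these assumptions, bf16 weights take 14 GB while int4 with group size 128 takes about 3.7 GB, so a bandwidth-bound decode step has roughly a quarter of the bytes to move.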
In torchao, we plan to support the following classes of techniques using PyTorch, made available via a simple UX and following the one-file-per-technique principle.
LLM weight only quantization techniques
Post training quantization
The two most popular techniques externally are GPTQ and AWQ, available via AutoGPTQ and AutoAWQ, which include both the technique and performant kernels for the quantized ops.
To that end, we will start by re-implementing the GPTQ and AWQ techniques in torchao using PyTorch, via a simple, intuitive UX that supports saving and loading quantized models while realizing the memory savings on disk. Some open questions we need to address here include:
How much VRAM will be required for different quantization techniques?
How do we convert to and from weights quantized for different backends (CPU and GPU currently use different weight-packing formats)?
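One small piece of the packing question is nibble order: int4 weights are stored two per byte, and layouts can disagree on which value goes in the low nibble. The sketch below converts between two hypothetical layouts (low-nibble-first and high-nibble-first). Real backend formats, such as the one used by tinygemm on GPU, also tile and interleave values for the kernel, so an actual converter has more work to do than this.

```python
def pack_int4(codes, low_nibble_first=True):
    """Pack unsigned 4-bit codes (0..15) two per byte."""
    if len(codes) % 2:
        raise ValueError("need an even number of codes")
    out = bytearray()
    for a, b in zip(codes[0::2], codes[1::2]):
        first, second = (a, b) if low_nibble_first else (b, a)
        out.append((second << 4) | first)  # `first` lands in the low nibble
    return bytes(out)


def unpack_int4(packed, low_nibble_first=True):
    """Recover the codes in their original order."""
    codes = []
    for byte in packed:
        lo, hi = byte & 0x0F, byte >> 4
        codes.extend((lo, hi) if low_nibble_first else (hi, lo))
    return codes


def repack_int4(packed, src_low_first, dst_low_first):
    """Convert between the two hypothetical layouts via the unpacked codes."""
    return pack_int4(unpack_int4(packed, src_low_first), dst_low_first)
```

Going through the unpacked codes keeps the converter obviously correct; a production version would operate on whole tensors with bit operations instead of Python loops.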
In the future, as more cutting-edge techniques are introduced, researchers can implement them directly in torchao, or our team can re-implement them in PyTorch.
Weight and activation quantization techniques
Post training quantization
We've already implemented W8A8 quantization via the int_mm kernel in core. This has shown speedups on models like SAM and SDXL without any impact on model accuracy, and can be turned on via a simple one-line UX implemented via module swap or tensor subclass.
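The arithmetic behind W8A8 dynamic quantization can be sketched in plain Python: activations get a symmetric per-row (per-token) int8 scale computed from each input at runtime, weights get a per-output-channel scale, the int8 products are accumulated as integers (the part an integer matmul kernel such as int_mm performs on GPU), and the result is rescaled. This mirrors the math only, not the kernel, and the function names are hypothetical.

```python
def quantize_per_row(rows):
    """Symmetric int8 quantization with one scale per row: scale = max|v| / 127."""
    q_rows, scales = [], []
    for row in rows:
        scale = max(abs(v) for v in row) / 127 or 1.0  # guard all-zero rows
        q_rows.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def w8a8_linear(x, weight):
    """Approximate x @ weight.T with int8 operands, integer accumulation, then rescaling."""
    xq, x_scales = quantize_per_row(x)       # per token, recomputed for every input ("dynamic")
    wq, w_scales = quantize_per_row(weight)  # per output channel; can be computed once
    return [
        [
            sum(a * b for a, b in zip(x_row, w_row)) * x_scales[i] * w_scales[j]
            for j, w_row in enumerate(wq)
        ]
        for i, x_row in enumerate(xq)
    ]


x = [[0.4, -1.0, 0.3]]
weight = [[1.0, 0.6, -0.7], [0.2, 0.1, 0.3]]
y = w8a8_linear(x, weight)
```

The per-token quantize step and the final rescale are the overhead that makes small layers a poor fit, since they cost roughly the same regardless of how much matmul work they wrap.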
However, the challenge is that some smaller layer shapes might not benefit from quantization due to the overhead of quantizing and dequantizing the activation tensors. Users can either statically skip quantizing these layers or use a higher-level API that figures out which layers are sensitive to quantization. We plan to provide such an API via the auto quantizer, which applies quantization to the layers that stand to benefit most, so users get the benefits of quantization without having to worry much about which configs to use.
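One simple policy an auto quantizer could use is to benchmark each layer in float and quantized form on representative inputs and keep quantization only where it is measurably faster. The sketch below shows just that decision step; the function name, the `min_speedup` threshold, and the timing numbers are made up for illustration and do not describe torchao's implementation.

```python
def choose_layers_to_quantize(timings, min_speedup=1.0):
    """timings maps layer name -> (float_ms, quantized_ms), measured on representative inputs.

    Returns the sorted names of layers whose quantized version beats the float
    version by more than `min_speedup`.
    """
    return sorted(
        name
        for name, (float_ms, quant_ms) in timings.items()
        if quant_ms > 0 and float_ms / quant_ms > min_speedup
    )


# Illustrative numbers only: the small "head" layer is slower once quantized,
# because quantizing/dequantizing its activations dominates the matmul.
timings = {"attn.qkv": (0.80, 0.50), "mlp.fc1": (1.20, 0.70), "head": (0.05, 0.09)}
selected = choose_layers_to_quantize(timings)
```

Raising `min_speedup` above 1.0 leaves a margin for timing noise, at the cost of skipping layers with only marginal gains.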
Quantization aware training
Techniques here require fine-tuning the model to reduce the accuracy impact of quantization. Recent research like LLM-QAT is promising, showing that we can go down to W4A8 and a 4-bit KV cache for LLMs. Moreover, newer lower-bit techniques like AQLM and QuIP# also include a fine-tuning component to improve model accuracy.
We will include the APIs and workflow to enable users to do QAT on LLMs, starting with implementing the LLM-QAT paper in torchao and further extending it to support other dtypes like MX4.
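QAT generally relies on fake quantization: during training, values are quantized and immediately dequantized so the forward pass sees the rounding error, while gradients flow through the straight-through estimator (STE), which treats rounding as the identity and zeroes the gradient where values were clipped. The scalar sketch below shows only this generic building block (with an int4 range by default), not the full LLM-QAT recipe; the function names are hypothetical.

```python
def fake_quantize(x, scale, qmin=-8, qmax=7):
    """Quantize-then-dequantize a value so training sees the int4 rounding error."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale


def fake_quantize_grad(x, scale, upstream_grad, qmin=-8, qmax=7):
    """Straight-through estimator: pass the gradient unless the value was clipped."""
    inside = qmin <= x / scale <= qmax
    return upstream_grad if inside else 0.0
```

In a PyTorch implementation the same behaviour is usually expressed with a custom autograd function or the `x + (fq(x) - x).detach()` trick, so the backward pass of the rounding is the identity.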
Kernels
Optimized kernels are key to making models run faster during inference. Today, core already has performant kernels like int_mm and 4-bit weight-quantization kernels for CPU (via Intel) and GPU (via tinygemm). torchao will host performant kernels that work with different backends, along with a guide on how to plug these kernels into PyTorch models via the custom ops API. These kernels will compose with torch.compile, provided the user writes a meta kernel implementation for them. For ExecuTorch, the expectation is that if the user provides a kernel that works with ExecuTorch, it should also work in eager mode.
We will also engage directly with the community to upstream their performant kernels into torchao.
Autotuner
To use any CUDA kernel efficiently, we need to pick the right kernel hyperparameters, and the same is true for eager-mode kernels. A kernel autotuner will help here. We expect that the auto quantizer, together with the kernel autotuner, will make int8 dynamic quantization and int8/int4 weight-only quantization more usable and performant. A WIP example of what this might look like can be found here.
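At its core, a kernel autotuner runs each candidate configuration on a representative input, times it, and keeps the fastest. The sketch below shows that loop with hypothetical names and a stand-in workload; for real CUDA kernels, the timed region must also synchronize the device (for example with `torch.cuda.synchronize()`), because kernel launches are asynchronous.

```python
import time


def autotune(run, candidate_configs, warmup=1, repeats=5):
    """Time `run(config)` for every candidate and return (fastest_config, seconds_per_call)."""
    best_config, best_time = None, float("inf")
    for config in candidate_configs:
        for _ in range(warmup):  # exclude one-off costs such as compilation or caching
            run(config)
        start = time.perf_counter()
        for _ in range(repeats):
            run(config)
        elapsed = (time.perf_counter() - start) / repeats
        if elapsed < best_time:
            best_config, best_time = config, elapsed
    return best_config, best_time


def fake_kernel(config):
    time.sleep(config["sleep_s"])  # stand-in for launching a kernel with this config


best, seconds = autotune(fake_kernel, [{"sleep_s": 0.002}, {"sleep_s": 0.0}, {"sleep_s": 0.004}])
```

Results are only valid for the shapes and hardware they were measured on, which is why autotuners usually cache the winner keyed by input shape.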
Release engineering
Shipping optimized, custom kernels requires extensibility mechanisms and release channels. We have custom operator support that integrates broadly, but our release mechanism might need to be optimized. It can be quite difficult to ship custom binaries across a broad range of operating systems and accelerators.
We can add a conversion utility from popular model storage formats like GGUF into PyTorch's state_dict format. This will enable users to take a pre-existing quantized model from llama.cpp and run it via PyTorch eager mode for desktop CPU/GPU, and via ExecuTorch for on-device cases. We'll share more details here soon.
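As a taste of what such a converter involves, the sketch below dequantizes one of the simpler block types stored in GGUF files. In llama.cpp's Q8_0 format, a tensor is a sequence of blocks, each holding one fp16 scale followed by 32 int8 values, and each value dequantizes to scale × value. The function name is hypothetical, and parsing the GGUF header, metadata, and tensor-info table, as well as the other block types, is not shown.

```python
import struct

QK8_0 = 32                                   # values per Q8_0 block
Q8_0_BLOCK = struct.Struct(f"<e{QK8_0}b")    # little-endian fp16 scale, then 32 int8 quants


def dequantize_q8_0(raw: bytes) -> list[float]:
    """Dequantize a buffer of consecutive Q8_0 blocks: value = scale * quant."""
    if len(raw) % Q8_0_BLOCK.size:
        raise ValueError("buffer is not a whole number of Q8_0 blocks")
    values = []
    for offset in range(0, len(raw), Q8_0_BLOCK.size):
        scale, *quants = Q8_0_BLOCK.unpack_from(raw, offset)
        values.extend(scale * q for q in quants)
    return values


block = Q8_0_BLOCK.pack(0.5, *range(-16, 16))  # one synthetic block for illustration
weights = dequantize_q8_0(block)
```

A state_dict converter could either dequantize like this or, to keep the memory savings, map the scales and int8 values onto an equivalent quantized tensor representation.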
In addition to quantization, we've also seen promising results with sparsity on GPUs. We will share more updates on what torchao will host in the space of sparsity and pruning in the near future.
We'd love to hear any feedback or questions from the OSS community on this RFC. Thank you!
cc @msaroufim @cpuhrsch @jerryzh168 @HDCharles @andrewor14 @jcaip @jisaacso
import torch
import torchvision.models.vision_transformer as models
from torch._inductor import config as inductorconfig
from torchao.quantization import apply_dynamic_quant

# Load a pretrained Vision Transformer and move it to the GPU in bfloat16
model = models.vit_b_16(pretrained=True)
model.eval().cuda().to(torch.bfloat16)

# Swap linear layers for dynamically quantized (W8A8) versions
apply_dynamic_quant(model)

# Fuse the int_mm output with the scale multiply, then compile
inductorconfig.force_fuse_int_mm_with_mul = True
model = torch.compile(model, mode='max-autotune')

input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
model(input_tensor)
This causes a crash:
[...]
self.out_proj.weight,
File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1704, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___encoder_layers_encoder_layer_0_self_attention(*(FakeTensor(..., device='cuda:0', size=(1, 197, 768), dtype=torch.bfloat16,
grad_fn=<NativeLayerNormBackward0>), FakeTensor(..., device='cuda:0', size=(1, 197, 768), dtype=torch.bfloat16,
grad_fn=<NativeLayerNormBackward0>), FakeTensor(..., device='cuda:0', size=(1, 197, 768), dtype=torch.bfloat16,
grad_fn=<NativeLayerNormBackward0>)), **{'need_weights': False}):
'DynamicallyPerAxisQuantizedLinear' object has no attribute 'weight'
from user code:
File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 298, in forward
x = self.encoder(x)
File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 157, in forward
return self.ln(self.layers(self.dropout(input)))
File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/cpuhrsch/miniconda3/envs/nightly20240318py310/lib/python3.10/site-packages/torchvision/models/vision_transformer.py", line 113, in forward
x, _ = self.self_attention(x, x, x, need_weights=False)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
This issue tracks outstanding issues for a torchao 0.1 release
New Functionality
Tutorials/BE
If time permits (or v0.2)
This issue tracks outstanding feature requests for torchao. If you'd like a specific feature to be added to torchao, please comment directly here.
Quantization Techniques (based on planned, new requests)
DTypes
Sparsity APIs
Kernels
cc @cpuhrsch
When using the new SDPA, why doesn't it need to be trained, now that parameters have been added?
Traceback (most recent call last):
File "C:/Program Files/JetBrains/PyCharm Community Edition 2023.2.1/plugins/python-ce/helpers/pydev/pydevd.py", line 1527, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "C:\Program Files\JetBrains\PyCharm Community Edition 2023.2.1\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "C:\code\foobar\scripts\quantize.py", line 37, in <module>
swap_linear_with_smooth_fq_linear(model)
File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 219, in swap_linear_with_smooth_fq_linear
swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha)
File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 219, in swap_linear_with_smooth_fq_linear
swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha)
File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 219, in swap_linear_with_smooth_fq_linear
swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha)
[Previous line repeated 1 more time]
File "C:\code\py-envs\foobar\lib\site-packages\torchao\quantization\smoothquant.py", line 215, in swap_linear_with_smooth_fq_linear
target_cls = source_cls_to_target_cls[type(child)]
KeyError: <class 'torch.nn.modules.linear.NonDynamicallyQuantizableLinear'>
Expected: the NonDynamicallyQuantizableLinear layer is skipped (possibly with a warning) or properly handled.
Actual: an exception is raised.
It sounds like HDCharles was planning on fixing this more generally: pytorch/pytorch#58969
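The traceback shows an exact-type dictionary lookup (`source_cls_to_target_cls[type(child)]`). `torch.nn.modules.linear.NonDynamicallyQuantizableLinear` is a subclass of `nn.Linear`, so it misses the `nn.Linear` key and raises `KeyError`. The sketch below implements the "skipped with a warning" expectation using stand-in classes and a hypothetical helper name; a caller would leave the child module unchanged when it gets `None`. Skipping, rather than walking base classes to reuse the `nn.Linear` target, seems the safer default here: `nn.MultiheadAttention` uses this class for `out_proj` and reads `out_proj.weight` directly, which is the same failure the ViT repro above runs into.

```python
import warnings


class Linear:                                   # stand-in for torch.nn.Linear
    pass


class NonDynamicallyQuantizableLinear(Linear):  # stand-in for the torch.nn subclass of the same name
    pass


class QuantLinearStandIn:                       # hypothetical swap target
    pass


def lookup_target_cls(mapping, module):
    """Return the replacement class for type(module), or None (after warning) to skip it."""
    target = mapping.get(type(module))
    if target is None:
        warnings.warn(
            f"skipping {type(module).__name__}: no replacement registered", stacklevel=2
        )
    return target


mapping = {Linear: QuantLinearStandIn}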
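```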
Traceback (most recent call last):
File "C:\code\foo\scripts\quantize.py", line 4, in <module>
from torchao.quantization.smoothquant import (
File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\__init__.py", line 7, in <module>
from .smoothquant import * # noqa: F403
File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\smoothquant.py", line 17, in <module>
import torchao.quantization.quant_api as quant_api
File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\quant_api.py", line 18, in <module>
from .subclass import (
File "C:\code\py-envs\foo\lib\site-packages\torchao\quantization\subclass.py", line 13, in <module>
from torch.utils._python_dispatch import return_and_correct_aliasing
ImportError: cannot import name 'return_and_correct_aliasing' from 'torch.utils._python_dispatch' (C:\code\py-envs\foo\lib\site-packages\torch\utils\_python_dispatch.py)
This project seems to rely on torch nightly, which exports return_and_correct_aliasing. It might be worthwhile to document this. I suppose one could argue it's obvious enough from this being an experimental repo, but it was surprising to me.
Nice work team, I'm looking forward to using this package.
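Besides documenting the requirement, a package can detect a missing symbol and fail with a clearer message than a bare `ImportError` deep in the import chain. The sketch below uses a hypothetical helper, `has_attribute`, built on the standard `importlib`; it checks for the specific symbol instead of a version string, because nightly version numbers do not map cleanly onto which private helpers exist.

```python
import importlib


def has_attribute(module_name, attr):
    """True if `module_name` can be imported and exposes `attr`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return hasattr(module, attr)


if not has_attribute("torch.utils._python_dispatch", "return_and_correct_aliasing"):
    print("This torchao build needs a newer (nightly) torch that provides "
          "torch.utils._python_dispatch.return_and_correct_aliasing")
```

In a real package, the check would run in the package's `__init__` and raise an `ImportError` with this message instead of printing.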
The code is out; it's quite simple and short.
Opening this so I can track how to add this to ao and make sure it works well with torch.compile(). This will likely need Blackwell to perform decently.
https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
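For reference, BitNet b1.58 quantizes weights to ternary values with an "absmean" scheme: divide by the mean absolute value of the weight tensor, round, and clip to {-1, 0, 1}, keeping that mean as the scale. The plain-Python sketch below shows only this weight step (activation quantization and the straight-through training setup are omitted). It clamps the scale from below by a small epsilon in the spirit of the linked training tips, and the function names are hypothetical.

```python
def absmean_ternary(weights, eps=1e-5):
    """Ternarize a flat weight list: codes in {-1, 0, 1} plus one per-tensor scale."""
    gamma = max(sum(abs(w) for w in weights) / len(weights), eps)  # mean |w|, kept away from 0
    codes = [max(-1, min(1, round(w / gamma))) for w in weights]
    return codes, gamma


def dequantize_ternary(codes, gamma):
    """Map ternary codes back to floats: w ~= code * gamma."""
    return [c * gamma for c in codes]


codes, gamma = absmean_ternary([0.9, -0.05, 0.4, -1.2])
approx = dequantize_ternary(codes, gamma)
```

Since the codes are only -1, 0, or 1, a matmul against them reduces to additions and subtractions followed by a single multiply by `gamma`, which is where specialized kernels get their speedup.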
Does torch already support int8 mm on CUDA? Also, which version of torch is needed to run torchao now?