Comments (5)
Hi @ecilay, please try the following code to make export work. Changes include: a) providing a tuple input to capture_pre_autograd_graph, and b) fixing the import path for XNNPACKQuantizer.
import torch

def export_quantize(m):
    # Original issue: bn expected 4d input, got 3
    from torch._export import capture_pre_autograd_graph
    example_inputs = (torch.randn(1, 3, 224, 224),)  # Note: example inputs must be a tuple
    m = capture_pre_autograd_graph(m, example_inputs)
    # we now have a model with aten ops

    # Step 2: quantization
    from torch.ao.quantization.quantize_pt2e import (
        prepare_pt2e,
        convert_pt2e,
    )
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (  # Note: updated import path
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    # backend developers write their own Quantizer and expose methods so
    # users can express how they want the model to be quantized
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    # calibration omitted
    m = convert_pt2e(m)
    # we now have a model with aten ops doing integer computations where possible
    return m
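Note that the code above skips calibration, which PT2E static quantization needs in order to pick sensible activation ranges; running it uncalibrated is one plausible cause of wrong results. Below is a minimal calibration sketch (the `calibrate` helper and `calibration_loader` name are assumptions, any iterable of input batches works), to be run between prepare_pt2e and convert_pt2e:

```python
import torch

def calibrate(prepared_model, calibration_loader, num_batches=32):
    # Feed representative inputs through the prepared model so the
    # observers inserted by prepare_pt2e can record activation ranges.
    prepared_model.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(calibration_loader):
            if i >= num_batches:
                break
            prepared_model(images)
```

Usage: `m = prepare_pt2e(m, quantizer); calibrate(m, calibration_loader); m = convert_pt2e(m)`.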
from pytorch.
Thanks, your fix works. However, the ResNet classification results are totally wrong; you can reproduce with the inference code below. Also, the runtime of the model exported this way is roughly two to three times slower: pt_time: 0.0398 vs quantize_time: 0.0918
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("crane.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
# model = resnet50(weights=weights)
# model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

def run_inference(model):
    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)
    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")
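For reference, the pt_time / quantize_time numbers above can be reproduced with a simple wall-clock loop like the sketch below (`time_model` is a hypothetical helper, not part of torch; the warmup runs exclude one-time costs such as lazy initialization):

```python
import time
import torch

def time_model(model, batch, warmup=5, iters=20):
    # Warm up so one-time costs don't skew the measurement.
    with torch.no_grad():
        for _ in range(warmup):
            model(batch)
        start = time.perf_counter()
        for _ in range(iters):
            model(batch)
    return (time.perf_counter() - start) / iters  # seconds per call
```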
Besides, if I use fx export, the outputs are numerically very different from the original model, though the softmax classification results are the same. Is this expected?
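One way to quantify "numerically very different" is to compare the raw logits of the two models directly. A small sketch (the `compare_outputs` helper is hypothetical) that reports the max absolute difference and whether the top-1 prediction matches:

```python
import torch

def compare_outputs(model_a, model_b, batch):
    # Run both models on the same batch and compare raw outputs.
    with torch.no_grad():
        out_a = model_a(batch)
        out_b = model_b(batch)
    max_abs_diff = (out_a - out_b).abs().max().item()
    same_top1 = out_a.argmax(dim=-1).equal(out_b.argmax(dim=-1))
    return max_abs_diff, same_top1
```

Some numerical drift is expected after quantization even when the argmax (and hence the predicted class) agrees, since logits are reconstructed from low-precision integer arithmetic.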
Do you mind sharing "crane.jpg" and the log? Thanks!
The outputs below are from the original model and the export-quantized model, respectively:
geyser: 5.9%
iron: 99.6%