Comments (7)

Xia-Weiwen commented on May 24, 2024

Hi @RazeBerry, how about adding the following before get_default_qconfig?

torch.backends.quantized.engine = 'qnnpack'

from pytorch.

malfet commented on May 24, 2024

fbgemm was always too x86 oriented, not sure what's the thing with qnnpack though:

% python -c "import torch;print(torch.backends.quantized.supported_engines)"
['qnnpack', 'none']

RazeBerry commented on May 24, 2024

fbgemm was always too x86 oriented, not sure what's the thing with qnnpack though:

% python -c "import torch;print(torch.backends.quantized.supported_engines)"
['qnnpack', 'none']

Sorry about the error in the code. Even if I change it to qconfig = get_default_qconfig('qnnpack'), the issue persists.

Furthermore, qnnpack is there:

(base) sihao@Sihaos-MacBook-Pro-2 documents % python -c "import torch;print(torch.backends.quantized.supported_engines)"
['qnnpack', 'none']

malfet commented on May 24, 2024

@RazeBerry as @Xia-Weiwen suggested, specifying the quantized engine should fix your problem, though I agree that the qnnpack backend (as the only one available on the ARM platform) should have been selected by default.
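
As a rough sketch of the suggestion above, a script can check torch.backends.quantized.supported_engines and set the engine explicitly before building a qconfig. The helper name pick_quantized_engine and the preference order are hypothetical, chosen only for illustration; they are not part of PyTorch.

```python
import torch


def pick_quantized_engine(preferred=("qnnpack", "fbgemm")):
    """Activate and return the first preferred engine this build supports.

    Hypothetical helper: the name and the preference order are assumptions.
    """
    supported = torch.backends.quantized.supported_engines
    for name in preferred:
        if name in supported:
            # Same assignment as suggested in the thread, made explicit.
            torch.backends.quantized.engine = name
            return name
    raise RuntimeError(f"No usable quantized engine among {supported}")


engine = pick_quantized_engine()
print(engine, torch.backends.quantized.supported_engines)
```

Calling this once at start-up, before get_default_qconfig, avoids relying on whichever engine the build selects by default.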

RazeBerry commented on May 24, 2024

@Xia-Weiwen @malfet

That works! Thank you so much! It is a bit confusing that it has to be specified explicitly at all, considering qnnpack is the only engine available on ARM. It would be great if the package could be modified so one doesn't need that line. I use many packages that have PyTorch as a dependency, and I almost always run into quantization problems because that line is not set in those packages.

RazeBerry commented on May 24, 2024

@Xia-Weiwen @malfet Just one last question: is the error below to be expected when running on ARM architecture? I have seen Whisper throw a similar type of error due to a missing implementation. Thank you very much.

import torch
from torch import nn
import torch.nn.quantized as nnq
from torch.quantization import get_default_qconfig, prepare, convert

# Define the model
class SimpleLinearModel(nn.Module):
    def __init__(self):
        super(SimpleLinearModel, self).__init__()
        self.linear = nn.Linear(5, 10)  # Example dimensions

    def forward(self, x):
        return self.linear(x)

# Instantiate the model
model = SimpleLinearModel()

torch.backends.quantized.engine = 'qnnpack'

# Define the qconfig (using 'fbgemm' or 'qnnpack' configuration)
qconfig = get_default_qconfig('qnnpack')  # or 'fbgemm'

# Apply the qconfig to the model
model.qconfig = qconfig

# Prepare the model for quantization
model.eval()  # Set the model to evaluation mode
prepared_model = prepare(model)

# Calibrate the model with sample data
calibration_data = torch.randn(64, 5)  # Generate sample data for calibration
prepared_model(calibration_data)

# Convert the prepared model to a quantized model
quantized_model = convert(prepared_model)

# Test the quantized model
input_data = torch.randn(1, 5)
output = quantized_model(input_data)
print(output)

Traceback (most recent call last):
  File "/Users/sihao/Documents/errorreproducer.py", line 39, in <module>
    output = quantized_model(input_data)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/Documents/errorreproducer.py", line 13, in forward
    return self.linear(x)
           ^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/linear.py", line 168, in forward
    return torch.ops.quantized.linear(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sihao/.pyenv/versions/3.11.0/lib/python3.11/site-packages/torch/_ops.py", line 755, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear' is only available for these backends: [MPS, Meta, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

MPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:75 [backend fallback]
Meta: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/quantized/cpu/qlinear.cpp:1140 [kernel]
BackendSelect: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradMeta: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/TraceTypeManual.cpp:297 [backend fallback]
AutocastCPU: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastCUDA: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]

Xia-Weiwen commented on May 24, 2024

@RazeBerry You are using eager mode quantization in your script. In this mode, you need to insert QuantStub and DeQuantStub in your model to quantize input and dequantize output. See doc here: https://pytorch.org/docs/stable/quantization.html#post-training-static-quantization
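
A minimal sketch of what that could look like for the model in this thread, assuming the same 5-to-10 linear layer and random calibration data. The class name StubbedLinearModel is made up for illustration, and the engine fallback to fbgemm is an assumption for machines where qnnpack is unavailable.

```python
import torch
from torch import nn
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    get_default_qconfig,
    prepare,
)


class StubbedLinearModel(nn.Module):
    """Same linear layer as the reproducer, wrapped in quant/dequant stubs."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the input
        self.linear = nn.Linear(5, 10)
        self.dequant = DeQuantStub()  # quantized -> float at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.linear(x)
        return self.dequant(x)


# Assumption: prefer qnnpack (the ARM engine), otherwise fall back to fbgemm.
supported = torch.backends.quantized.supported_engines
engine = "qnnpack" if "qnnpack" in supported else "fbgemm"
torch.backends.quantized.engine = engine

model = StubbedLinearModel().eval()
model.qconfig = get_default_qconfig(engine)

prepared = prepare(model)          # insert observers
prepared(torch.randn(64, 5))       # calibrate on sample data
quantized = convert(prepared)      # swap in quantized modules

output = quantized(torch.randn(1, 5))
print(output.shape, output.dtype)
```

With the stubs in place, the quantized linear layer receives a quantized tensor rather than a plain float CPU tensor, which is the mismatch the NotImplementedError above points at.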
Alternatively, you may use FX mode, so that quant/dequant are inserted automatically and you don't have to insert them yourself: https://pytorch.org/docs/stable/quantization.html#prototype-maintaince-mode-fx-graph-mode-quantization
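
For comparison, here is a hedged sketch of the FX route applied to the original, unmodified model. It assumes a PyTorch version that ships torch.ao.quantization.quantize_fx (prepare_fx, convert_fx) and get_default_qconfig_mapping; the engine fallback to fbgemm is again an assumption.

```python
import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class SimpleLinearModel(nn.Module):
    """The model from the reproducer, with no stubs added."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)


# Assumption: prefer qnnpack (the ARM engine), otherwise fall back to fbgemm.
supported = torch.backends.quantized.supported_engines
engine = "qnnpack" if "qnnpack" in supported else "fbgemm"
torch.backends.quantized.engine = engine

model = SimpleLinearModel().eval()
qconfig_mapping = get_default_qconfig_mapping(engine)
example_inputs = (torch.randn(1, 5),)

# FX traces the model and places observers and quant/dequant nodes itself.
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(torch.randn(64, 5))       # calibrate on sample data
quantized = convert_fx(prepared)

output = quantized(torch.randn(1, 5))
print(output.shape, output.dtype)
```

The trade-off is that FX mode requires the model to be symbolically traceable, which holds for this simple module but not for every model.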
