
Comments (5)

BoyuanFeng commented on May 12, 2024

Hi @ecilay, please try the following code to make export work. Changes include a) providing a tuple input to capture_pre_autograd_graph and b) fixing the import path for XNNPACKQuantizer.

import torch


def export_quantize(m):
    # Issue: bn expected 4d, got 3
    # Step 1. program capture
    from torch._export import capture_pre_autograd_graph

    example_inputs = (torch.randn(1, 3, 224, 224),)  # Note: input should be a tuple
    m = capture_pre_autograd_graph(m, example_inputs)
    # we get a model with aten ops

    # Step 2. quantization
    from torch.ao.quantization.quantize_pt2e import (
        prepare_pt2e,
        convert_pt2e,
    )
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (  # Note: updated import path
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # backend developers will write their own Quantizer and expose methods to allow
    # users to express how they want the model to be quantized
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)

    # calibration omitted

    m = convert_pt2e(m)
    # we have a model with aten ops doing integer computations when possible

    return m
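
The calibration step is omitted in the snippet above. As a rough sketch only: calibration generally means running a few representative inputs through the prepared model (the result of prepare_pt2e) so that its observers can record value ranges before convert_pt2e is called. The helper name `calibrate` and its `grad_context` parameter are hypothetical and not part of any PyTorch API; the function is duck-typed so it works with any callable.

```python
from contextlib import nullcontext


def calibrate(prepared_model, batches, grad_context=nullcontext):
    """Run representative batches through a prepared model.

    `prepared_model` is any callable (assumed here to be the output of
    prepare_pt2e). `grad_context` is a hypothetical hook: a zero-argument
    callable that returns a context manager wrapping the forward passes.
    Returns the number of batches that were fed through the model.
    """
    seen = 0
    with grad_context():
        for batch in batches:
            prepared_model(batch)  # outputs are discarded; only the forward pass matters
            seen += 1
    return seen
```

With PyTorch this might be called as `calibrate(m, calibration_batches, grad_context=torch.no_grad)` between prepare_pt2e and convert_pt2e, where `calibration_batches` is an assumed iterable of preprocessed image tensors.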

from pytorch.

ecilay commented on May 12, 2024

Thanks, your fix works. However, the ResNet classification results are completely wrong. You can reproduce this with the inference code below. Also, the model exported this way takes roughly two to three times as long to run: pt_time: 0.0398 vs quantize_time: 0.0918

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("crane.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
# model = resnet50(weights=weights)
# model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()



def run_inference(model):
    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")
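
The code that produced the pt_time and quantize_time figures is not shown in this thread. A minimal sketch of one way such numbers could be measured, using only the standard library; the helper name `mean_latency` and its `warmup` and `iters` parameters are assumptions for illustration:

```python
import time


def mean_latency(fn, *args, warmup=3, iters=20):
    """Return the mean wall-clock time in seconds of one call to fn(*args)."""
    # Warm-up calls are not timed, so one-off setup cost does not skew the mean.
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters
```

For example, `mean_latency(model, batch)` could be called once with the original model and once with the exported quantized model, using the same preprocessed `batch`, so the two numbers are directly comparable.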

ecilay commented on May 12, 2024

Besides, if I use FX export, the outputs are numerically very different from the original model's, although the softmax classification results are the same. Is this expected?
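
To make the question concrete, here is a small pure-Python illustration of how rounding values to a symmetric 8-bit grid changes the numbers while the top-1 index can stay the same. It is a textbook-style sketch, not the arithmetic that XNNPACKQuantizer or FX quantization actually performs, and the function names and sample logits are made up for the example.

```python
def fake_quantize_symmetric(values, num_bits=8):
    """Round values to a symmetric signed integer grid and map them back to floats."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Quantize (round and clamp to [-qmax, qmax]), then dequantize.
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def compare_outputs(reference, candidate):
    """Return (max absolute difference, whether the top-1 index matches)."""
    max_abs_diff = max(abs(r - c) for r, c in zip(reference, candidate))
    top1_ref = max(range(len(reference)), key=reference.__getitem__)
    top1_cand = max(range(len(candidate)), key=candidate.__getitem__)
    return max_abs_diff, top1_ref == top1_cand


logits = [2.31, -0.87, 5.42, 0.15]  # made-up values for illustration
diff, same_top1 = compare_outputs(logits, fake_quantize_symmetric(logits))
print(f"max abs diff: {diff:.4f}, same top-1: {same_top1}")
```

In a real network the rounding error accumulates across layers, so the difference at the output can be much larger than for a single rounding step like this one.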

BoyuanFeng commented on May 12, 2024

Do you mind sharing "crane.jpg" and the log? Thanks!

ecilay commented on May 12, 2024

(attached image: crane)
Below are the outputs of the original model and the exported quantized model:

geyser: 5.9%
iron: 99.6%
