Comments (4)

kunal-vaishnavi avatar kunal-vaishnavi commented on July 18, 2024 1

> The attached Phi-3 ONNX model is not shape inferred for all the operators. A couple of operators might have symbolically inferred shapes with dynamic axes. The vast majority of the operators are not shape inferred. For example, here is one of the subgraphs of the model that is not shape inferred, visualized in Netron. From my understanding, this is how Netron visualizes shape-inferred operators after running the model through the SymbolicShapeInference.infer_shapes tool, which I was not able to do for the Phi-3 model (this subgraph is from a different ONNX model).

You can find the inferred shapes in Netron by clicking on an operator and then clicking the '+' icon to the right of each input name and output name. Here is an example.

image

> I do not see the MatMulNBits operator in the list of supported operators you shared for the SymbolicShapeInference.infer_shapes tool, which might be why the tool is raising the error.

Yes, your error occurs because symbolic shape inference for MatMulNBits isn't implemented in SymbolicShapeInference.infer_shapes. We can add MatMulNBits to fix this.

> Were you able to successfully shape infer the Phi-3 model for all operators? I am not able to do it with the release version of onnxruntime 1.18.0. Which version of onnxruntime are you using?

The uploaded Phi-3 ONNX models are created via ONNX Runtime GenAI's model builder. The shape information for their operators is created in the model builder using onnx.helper.make_tensor_value_info and added to the ModelProto.

from onnxruntime.

tianleiwu avatar tianleiwu commented on July 18, 2024

@kunal-vaishnavi, could you take a look at whether symbolic shape inference works on Phi-3 models?

kunal-vaishnavi avatar kunal-vaishnavi commented on July 18, 2024

The uploaded Phi-3 ONNX models have already had their shapes symbolically inferred with dynamic axes.

The symbolic shape inference for most quantization operators is defined in each operator's spec. For example, here is the schema for MatMulNBits:

ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(MatMulNBits_ver1_doc)
    .Attr("K", "size of each input feature", AttributeProto::INT)
    .Attr("N", "size of each output feature", AttributeProto::INT)
    .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
    .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
    .Attr("accuracy_level",
          "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) "
          "(default unset). It is used to control how input A is quantized or downcast internally while "
          "doing computation, for example: 0 means input A will not be quantized or downcast while doing "
          "computation. 4 means input A can be quantized with the same block_size to int8 internally from "
          "type T1.",
          AttributeProto::INT, static_cast<int64_t>(0))
    .Input(0, "A", "The input tensor, not quantized", "T1")
    .Input(1, "B", "1 or 2 dimensional data blob", "T2")
    .Input(2, "scales", "quantization scale", "T1")
    .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional)
    .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional)
    .Input(5, "bias", "Bias to add to result. It should have shape [N].", "T1", OpSchema::Optional)
    .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
    .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
    .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.")
    .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.")
    .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.")
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // Type inference
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      // Shape inference
      int64_t in_features = getAttribute(ctx, "K", -1);
      int64_t out_features = getAttribute(ctx, "N", -1);
      MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true);
      // validate bias shape
      if (ctx.hasInput(5)) {
        if (!hasInputShape(ctx, 5)) {
          fail_shape_inference("bias shape must be known");
        }
        const auto& bias_shape = getInputShape(ctx, 5);
        if (bias_shape.dim_size() != 1 ||
            !bias_shape.dim(0).has_dim_value() ||
            bias_shape.dim(0).dim_value() != out_features) {
          fail_shape_inference("bias shape must be [N] where N = ", out_features);
        }
      }
    });
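
To make the shape rule in that schema easier to follow, here is a rough Python sketch of the same idea. The body of MatmulWithQuantWeightShapeInference is not shown in this thread, so the output rule used below (A's shape with its last dimension replaced by N, which keeps the rank the same as the schema's output doc describes) is an assumption, and matmul_nbits_output_shape is a hypothetical helper rather than an ONNX Runtime function.

```python
from typing import List, Optional, Sequence, Union

# An int is a known size; a str is a symbolic (dynamic) dimension name.
Dim = Union[int, str]


def matmul_nbits_output_shape(
    a_shape: Sequence[Dim],
    k: int,
    n: int,
    bias_shape: Optional[Sequence[Dim]] = None,
) -> List[Dim]:
    """Hypothetical helper: output shape of Y for input A and attributes K, N."""
    if len(a_shape) == 0:
        raise ValueError("input A must have at least one dimension")
    last = a_shape[-1]
    # Only a known (integer) last dimension can be checked against K.
    if isinstance(last, int) and last != k:
        raise ValueError(f"last dimension of A is {last}, expected K = {k}")
    # Mirrors the bias check in the schema: bias must be exactly [N].
    if bias_shape is not None and list(bias_shape) != [n]:
        raise ValueError(f"bias shape must be [N] where N = {n}")
    # Assumed rule: leading dimensions of A are kept, the last becomes N.
    return list(a_shape[:-1]) + [n]


# Hypothetical sizes, chosen only for illustration.
y_shape = matmul_nbits_output_shape(["batch_size", "sequence_length", 64], k=64, n=128)
print(y_shape)
```

Symbolic leading dimensions pass through unchanged, which is what lets the output keep its dynamic axes.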

Here is the list of supported operators whose shapes can be symbolically inferred in the SymbolicShapeInference.infer_shapes tool.

self.dispatcher_ = {
    "Add": self._infer_symbolic_compute_ops,
    "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
    "AveragePool": self._infer_Pool,
    "BatchNormalization": self._infer_BatchNormalization,
    "Cast": self._infer_Cast,
    "CategoryMapper": self._infer_CategoryMapper,
    "Compress": self._infer_Compress,
    "Concat": self._infer_Concat,
    "ConcatFromSequence": self._infer_ConcatFromSequence,
    "Constant": self._infer_Constant,
    "ConstantOfShape": self._infer_ConstantOfShape,
    "Conv": self._infer_Conv,
    "CumSum": self._pass_on_shape_and_type,
    "Div": self._infer_symbolic_compute_ops,
    "Einsum": self._infer_Einsum,
    "Expand": self._infer_Expand,
    "Equal": self._infer_symbolic_compute_ops,
    "Floor": self._infer_symbolic_compute_ops,
    "Gather": self._infer_Gather,
    "GatherElements": self._infer_GatherElements,
    "GatherND": self._infer_GatherND,
    "Identity": self._pass_on_shape_and_type,
    "AllReduce": self._pass_on_shape_and_type,
    "If": self._infer_If,
    "Loop": self._infer_Loop,
    "MatMul": self._infer_MatMul,
    "MatMulInteger16": self._infer_MatMulInteger,
    "MaxPool": self._infer_Pool,
    "Max": self._infer_symbolic_compute_ops,
    "MemcpyFromHost": self._pass_on_shape_and_type,
    "MemcpyToHost": self._pass_on_shape_and_type,
    "Min": self._infer_symbolic_compute_ops,
    "MoE": self._pass_on_shape_and_type,
    "Mul": self._infer_symbolic_compute_ops,
    "NonMaxSuppression": self._infer_NonMaxSuppression,
    "NonZero": self._infer_NonZero,
    "OneHot": self._infer_OneHot,
    "Pad": self._infer_Pad,
    "Range": self._infer_Range,
    "Reciprocal": self._pass_on_shape_and_type,
    "ReduceSum": self._infer_ReduceSum,
    "ReduceProd": self._infer_ReduceProd,
    "Reshape": self._infer_Reshape,
    "Resize": self._infer_Resize,
    "Round": self._pass_on_shape_and_type,
    "Scan": self._infer_Scan,
    "ScatterElements": self._infer_ScatterElements,
    "SequenceAt": self._infer_SequenceAt,
    "SequenceInsert": self._infer_SequenceInsert,
    "Shape": self._infer_Shape,
    "Size": self._infer_Size,
    "Slice": self._infer_Slice,
    "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
    "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
    "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
    "Split": self._infer_Split,
    "SplitToSequence": self._infer_SplitToSequence,
    "Squeeze": self._infer_Squeeze,
    "Sub": self._infer_symbolic_compute_ops,
    "Tile": self._infer_Tile,
    "TopK": self._infer_TopK,
    "Transpose": self._infer_Transpose,
    "Unsqueeze": self._infer_Unsqueeze,
    "Where": self._infer_symbolic_compute_ops,
    "ZipMap": self._infer_ZipMap,
    "Neg": self._infer_symbolic_compute_ops,
    # contrib ops:
    "Attention": self._infer_Attention,
    "BiasAdd": self._infer_BiasAdd,
    "BiasGelu": self._infer_BiasGelu,
    "BiasSplitGelu": self._infer_BiasSplitGelu,
    "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
    "DequantizeLinear": self._infer_DequantizeLinear,
    "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
    "FastGelu": self._infer_FastGelu,
    "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
    "Gelu": self._infer_Gelu,
    "GemmFastGelu": self._infer_GemmFastGelu,
    "GemmFloat8": self._infer_GemmFloat8,
    "GroupNorm": self._infer_GroupNorm,
    "GroupQueryAttention": self._infer_GroupQueryAttention,
    "SparseAttention": self._infer_SparseAttention,
    "SkipGroupNorm": self._infer_SkipGroupNorm,
    "LayerNormalization": self._infer_LayerNormalization,
    "LongformerAttention": self._infer_LongformerAttention,
    "MultiHeadAttention": self._infer_MultiHeadAttention,
    "NhwcConv": self._infer_NhwcConv,
    "PackedAttention": self._infer_PackedAttention,
    "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
    "PagedAttention": self._infer_PagedAttention,
    "PythonOp": self._infer_PythonOp,
    "QuantizeLinear": self._infer_QuantizeLinear,
    "QuickGelu": self._infer_FastGelu,
    "RelativePositionBias": self._infer_RelativePositionBias,
    "RemovePadding": self._infer_RemovePadding,
    "RestorePadding": self._infer_RestorePadding,
    "RotaryEmbedding": self._infer_RotaryEmbedding,
    "SimplifiedLayerNormalization": self._infer_LayerNormalization,
    "SkipLayerNormalization": self._infer_SkipLayerNormalization,
    "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
}
self.aten_op_dispatcher_ = {
    "embedding": self._infer_Gather,
    "bitwise_or": self._infer_aten_bitwise_or,
    "diagonal": self._infer_aten_diagonal,
    "max_pool2d_with_indices": self._infer_aten_pool2d,
    "max": self._infer_aten_minmax,
    "min": self._infer_aten_minmax,
    "multinomial": self._infer_aten_multinomial,
    "unfold": self._infer_aten_unfold,
    "argmax": self._infer_aten_argmax,
    "avg_pool2d": self._infer_aten_pool2d,
    "_adaptive_avg_pool2d": self._infer_aten_pool2d,
    "numpy_T": self._infer_Transpose,
    "native_group_norm": self._infer_aten_group_norm,
    "upsample_nearest1d": self._infer_aten_upsample,
    "upsample_nearest2d": self._infer_aten_upsample,
    "upsample_nearest3d": self._infer_aten_upsample,
    "upsample_bicubic2d": self._infer_aten_upsample,
}
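
One quick way to see which operators in a graph have no entry in a table like this is to compare the graph's op types against the table's keys. The sketch below is only an illustration: count_unhandled_ops is a hypothetical helper, the handled set is a small excerpt of the keys listed above, and graph_ops is a made-up list rather than the op types of the real Phi-3 model.

```python
from collections import Counter
from typing import Dict, Iterable, Set


def count_unhandled_ops(op_types: Iterable[str], handled: Set[str]) -> Dict[str, int]:
    """Hypothetical helper: count op types that have no entry in `handled`."""
    return dict(Counter(op for op in op_types if op not in handled))


# Excerpt of the dispatcher keys listed above (not the full table).
handled = {
    "Add",
    "Cast",
    "Gather",
    "GroupQueryAttention",
    "MatMul",
    "Mul",
    "RotaryEmbedding",
    "SimplifiedLayerNormalization",
    "SkipSimplifiedLayerNormalization",
}

# Made-up op types for illustration only.
graph_ops = [
    "Gather",
    "SimplifiedLayerNormalization",
    "MatMulNBits",
    "RotaryEmbedding",
    "GroupQueryAttention",
    "MatMulNBits",
    "Mul",
]

missing = count_unhandled_ops(graph_ops, handled)
print(missing)
```

For a real model loaded with onnx.load, the op types could be collected with node.op_type for each node in model.graph.node.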

shamith2 avatar shamith2 commented on July 18, 2024

@kunal-vaishnavi, thanks for the response. I have a few questions and comments from my side:

  1. The attached Phi-3 ONNX model is not shape inferred for all the operators. A couple of operators might have symbolically inferred shapes with dynamic axes. The vast majority of the operators are not shape inferred. For example, here is one of the subgraphs of the model that is not shape inferred, visualized in Netron:

graph

From my understanding, this is how Netron visualizes shape-inferred operators after running the model through the SymbolicShapeInference.infer_shapes tool, which I was not able to do for the Phi-3 model (this subgraph is from a different ONNX model):

graph_inf

  2. I do not see the MatMulNBits operator in the list of supported operators you shared for the SymbolicShapeInference.infer_shapes tool, which might be why the tool is raising the error.

  3. Were you able to successfully shape infer the Phi-3 model for all operators? I am not able to do it with the release version of onnxruntime 1.18.0. Which version of onnxruntime are you using?
