

yf711 commented on August 20, 2024

Hi @MiroPsota, thanks for bringing up this issue!
Can the model with the NMS op run on a previous version of ONNX Runtime + TRT?
Also, could you share the standard model (without NMS) that you tested on the CPU/GPU/OpenVINO EPs?


MiroPsota commented on August 20, 2024

Here is a zip with all the models and an updated run.py.

With 1.18.0 and 1.18.1, the problem mentioned above occurs.

For the 1.17.1 tests (ORT GPU PyPI package), I used TensorRT 8.6.1.6 from here.
The OOM doesn't occur, but a different error does (op not implemented), and ORT adds unwanted copy operations to the host and back. See the log. I will investigate further.

The ONNX model for TensorRT runs without problems in mmdeploy, which uses TensorRT directly. One difference is that TRTBatchedNMS originally had a different domain, which I changed to trt.plugins according to the docs so that it can run in ORT.
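A minimal sketch of that domain change, assuming the exported model is loaded with the onnx Python package (an onnx.ModelProto exposes graph.node and opset_import as protobuf fields). The helper name move_op_to_domain and the opset version 1 declared for the plugin domain are illustrative assumptions, not part of mmdeploy or ORT.

```python
def move_op_to_domain(model, op_type="TRTBatchedNMS", new_domain="trt.plugins", version=1):
    """Reassign every top-level node of `op_type` to `new_domain`.

    `model` is expected to be an onnx.ModelProto, but any object with
    .graph.node and .opset_import works. Nodes inside subgraphs (If/Loop
    bodies) are not visited. Returns the number of nodes changed.
    """
    changed = 0
    for node in model.graph.node:
        if node.op_type == op_type and node.domain != new_domain:
            node.domain = new_domain
            changed += 1

    # ONNX expects every domain used by a node to be declared in
    # opset_import; version 1 for the plugin domain is an assumption.
    uses_domain = any(n.domain == new_domain for n in model.graph.node)
    if uses_domain and not any(o.domain == new_domain for o in model.opset_import):
        model.opset_import.add(domain=new_domain, version=version)
    return changed
```

Typical use would be m = onnx.load("end2end.onnx"); move_op_to_domain(m); onnx.save(m, "end2end_trt.onnx"), where both file names are placeholders.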


MiroPsota commented on August 20, 2024

I think the problem is with TensorRT 10.0, probably the same as mentioned here.

I solved it by compiling ORT 1.18.1 with TRT 8.6.1 and --use_tensorrt_oss_parser. More info here.


chilo-ms commented on August 20, 2024

@MiroPsota

The results you got with the following combinations are expected:

  • ORT 1.18.1 + TRT 10.0 (uses the built-in parser by default) ---> OOM
  • ORT 1.17.1 + TRT 8.6 (uses the built-in parser by default) ---> TRTBatchedNMS isn't recognized by TRT
  • ORT 1.18.1 + TRT 8.6 + TRT OSS parser ---> inference runs successfully

For TRT 10.0, I tested with trtexec and got the following error, but I didn't see it with TRT 8.6 trtexec.
(It seems TRT 10 has some issues; I'm reporting this to Nvidia.)

...
[07/09/2024-17:06:31] [I] [TRT] ----------------------------------------------------------------
[07/09/2024-17:06:31] [I] [TRT] No checker registered for op: TRTBatchedNMS. Attempting to check as plugin.
[07/09/2024-17:06:31] [I] [TRT] No importer registered for op: TRTBatchedNMS. Attempting to import as plugin.
[07/09/2024-17:06:31] [I] [TRT] Searching for plugin: TRTBatchedNMS, plugin_version: 1, plugin_namespace:
[07/09/2024-17:06:31] [I] [TRT] Successfully created plugin: TRTBatchedNMS
[07/09/2024-17:06:31] [W] [TRT] IElementWiseLayer with inputs /TRTBatchedNMS_output_2 and /Mul_output_0: first input has type Int32 but second input has type Int64.
[07/09/2024-17:06:31] [E] Error[4]: ITensor::getDimensions: Error Code 4: Internal Error (/Add_2: IElementWiseLayer with SUM operation has incompatible input types Int32 and Int64 type.)
[07/09/2024-17:06:31] [E] [TRT] ModelImporter.cpp:949: While parsing node number 434 [Add -> "/Add_2_output_0"]:
[07/09/2024-17:06:31] [E] [TRT] ModelImporter.cpp:950: --- Begin node ---
input: "/TRTBatchedNMS_output_2"
input: "/Mul_output_0"
output: "/Add_2_output_0"
name: "/Add_2"
op_type: "Add"

[07/09/2024-17:06:31] [E] [TRT] ModelImporter.cpp:951: --- End node ---
[07/09/2024-17:06:31] [E] [TRT] ModelImporter.cpp:954: ERROR: ModelImporter.cpp:195 In function parseNode:
[6] Invalid Node - /Add_2
ITensor::getDimensions: Error Code 4: Internal Error (/Add_2: IElementWiseLayer with SUM operation has incompatible input types Int32 and Int64 type.)
...

With TRT 8.6, ORT TRT can only register and run custom TRT plugins (in this case, the TRTBatchedNMS op) when it is built with the TRT OSS parser. With TRT 10, ORT can register custom TRT plugins with either the built-in parser or the OSS parser.
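On the ORT side, a library of custom plugins can be handed to the TensorRT EP through the trt_extra_plugin_lib_paths provider option (multiple paths separated by ';'). The sketch below only builds the providers list; the helper name build_trt_providers is hypothetical, the plugin path in the usage note is a placeholder, and it assumes the ORT build already matches one of the working TRT/parser combinations above.

```python
def build_trt_providers(plugin_libs):
    """Return a providers list for onnxruntime.InferenceSession.

    plugin_libs: paths to shared libraries that register extra TensorRT
    plugins (e.g. the mmdeploy library that contains TRTBatchedNMS).
    """
    trt_options = {
        # The TensorRT EP loads these libraries so their plugins are
        # registered before the model is parsed; ';' separates entries.
        "trt_extra_plugin_lib_paths": ";".join(plugin_libs),
    }
    # Priority order: nodes the TensorRT EP does not claim fall back to
    # CUDA, then CPU (which is where the extra host copies come from).
    return [
        ("TensorrtExecutionProvider", trt_options),
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ]
```

The list would then be passed as onnxruntime.InferenceSession("model.onnx", providers=build_trt_providers(["/path/to/libmmdeploy_op.so"])). With TRT 8.6, this still requires an ORT build with --use_tensorrt_oss_parser.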


chilo-ms commented on August 20, 2024

> it adds not wanted copy operations to a host and back. See the log. I will investigate further.

The added copy ops are expected. Because TRT doesn't recognize the TRTBatchedNMS op, ORT tries to place it on other EPs, which is what adds the copy ops. However, no other EP can run TRTBatchedNMS either, so it ends up with the error you saw.


chilo-ms commented on August 20, 2024

The OOM you saw with TRT 10 happens because the TRT EP calls the TRT parser to determine which nodes in the model are TRT-eligible, and in this process onnxruntime::TensorrtExecutionProvider::GetSupportedList() is called recursively many times until it runs out of memory.
I'm also checking with Nvidia about this.


chilo-ms commented on August 20, 2024

@MiroPsota

Nvidia helped identify the issue.
TRT 10.0 added support for INT64 types, whereas TensorRT 8.6 treated all INT64-typed tensors as INT32. The TRTBatchedNMS plugin internally always produces INT32 output tensors, even though they're marked as INT64 in the ONNX graph.
This causes the build failure in standalone TensorRT 10: the subsequent operation (here, the Add) expects INT64 operands, but it actually receives an INT32 + INT64 mix. The issue doesn't exist in TensorRT 8.6 because everything is cast to INT32, so the types match.

Updating the TRTBatchedNMS plugin (in libmmdeploy_op.so) to produce proper INT64 output tensors should fix the issue.
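The mismatch can be illustrated with a toy model of the type check (this is not TensorRT code, and the function names are hypothetical). The parser narrows ONNX INT64 tensors to INT32 under TRT 8.6 but keeps them as INT64 under TRT 10, while a plugin's output carries whatever type the plugin itself reports. An elementwise Add then rejects mixed operand types, as in the /Add_2 error above. The last case in the usage note models the proposed fix, a plugin that emits INT64.

```python
def engine_dtype(onnx_dtype, trt_major):
    """Type the engine uses for a tensor declared as `onnx_dtype` in ONNX."""
    # TRT 8.6 treats INT64 tensors as INT32; TRT 10 keeps INT64.
    if onnx_dtype == "INT64" and trt_major < 10:
        return "INT32"
    return onnx_dtype


def check_add(a, b):
    """Elementwise SUM requires both operands to have the same type."""
    if a != b:
        raise TypeError(f"incompatible input types {a} and {b}")
    return a


def build_add(trt_major, plugin_output):
    """Model /Add_2: plugin output (as reported by the plugin) + /Mul_output_0."""
    nms_out = plugin_output  # not narrowed by the parser; the plugin decides
    mul_out = engine_dtype("INT64", trt_major)  # INT64 in the ONNX graph
    return check_add(nms_out, mul_out)
```

With this model, build_add(8, "INT32") succeeds, build_add(10, "INT32") raises the mixed-type error, and build_add(10, "INT64") succeeds again once the plugin reports INT64.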
