NVIDIA / trt-samples-for-hackathon-cn

Simple samples for TensorRT programming
Apache License 2.0

05-parser loadNpz test failed on nvcr.io/nvidia/pytorch:22.09-py3 #67

Closed: jackzhou121 closed this issue 1 year ago

jackzhou121 commented 1 year ago

After installing the packages in requirements.txt, I tested the loadNpz plugin and it failed. Error message:

[03/06/2023-09:33:00] [TRT] [E] 2: [stdArchiveReader.h::readManyHelper::333] Error Code 2: Internal Error (Assertion prefix.count failed. Enums must always have at least one entry.)
[03/06/2023-09:33:00] [TRT] [E] 4: [runtime.cpp::deserializeCudaEngine::66] Error Code 4: Internal Error (Engine deserialization failed.)

wili-65535 commented 1 year ago

Thanks for finding this problem! In recent TensorRT versions, a plugin layer needs at least one input tensor (even a dummy one), but the example did not provide one. The problem is fixed by adding one input tensor, and it works now.
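For anyone hitting this before updating: a minimal sketch of the workaround, assuming your plugin library is already loaded and registered. The plugin name "MyPlugin" and version "1" are placeholders, not the actual loadNpz plugin.

import tensorrt as trt

logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# The plugin layer must receive at least one input tensor, even a dummy one
dummyInput = network.add_input("dummyInput", trt.float32, (1,))

# "MyPlugin" / "1" are placeholders for whatever plugin you registered
creator = trt.get_plugin_registry().get_plugin_creator("MyPlugin", "1")
plugin = creator.create_plugin("MyPlugin", trt.PluginFieldCollection())
pluginLayer = network.add_plugin_v2([dummyInput], plugin)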

jackzhou121 commented 1 year ago

Hi, when I build a TensorRT FP16 network, why are all input dtypes set to trt.float32 instead of trt.float16?

wili-65535 commented 1 year ago

When FP32 is used as the input data type, the implicit conversion from FP32 to FP16 is included in the TensorRT engine.

You can also use FP16 directly as the input data type if you need to; just change the data type parameter in the "add_input" API.
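A minimal sketch of the two choices, using the same builder setup, shape, and tensor name as the full script later in this thread:

import tensorrt as trt

logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)

nB, nC, nH, nW = 1, 1, 6, 9

# Option 1: FP32 input -- the engine accepts np.float32 data and casts to FP16 internally
# inputT0 = network.add_input("inputT0", trt.float32, (nB, nC, nH, nW))

# Option 2: FP16 input -- the engine accepts np.float16 data directly
inputT0 = network.add_input("inputT0", trt.float16, (nB, nC, nH, nW))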

jackzhou121 commented 1 year ago

When I build an FP16 network, if the input dtype is trt.float32 and all input arrays are np.float16, I found the enqueue function computes in FP32, not FP16. So how should I understand the implicit conversion?

wili-65535 commented 1 year ago

What do you mean by "the input dtype is trt.float32 and all input arrays are np.float16"?

Anyway, run the example code below to check where your problem is. You should see VERBOSE log output like the following:

...

The information below tells us that the convolution layer computes in float16:

[03/28/2023-04:57:55] [TRT] [V] Engine Layer Information:
Layer(CaskConvolution): (Unnamed Layer* 0) [Convolution], Tactic: 0x1a44b86b83a2d99d, inputT0 (Half[1,1,6,9]) -> (Unnamed Layer* 0) [Convolution]_output (Half[1,1,4,7])
...

The information below tells us that the data types of the input and output tensors are float16:

[ 0]Input -> DataType.HALF (1, 1, 6, 9) (1, 1, 6, 9) inputT0
[ 1]Output-> DataType.HALF (1, 1, 4, 7) (1, 1, 4, 7) (Unnamed Layer* 0) [Convolution]_output
...

#
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
from cuda import cudart
import tensorrt as trt

nB, nC, nH, nW = 1, 1, 6, 9
nCOut, nKernelHeight, nKernelWidth = 1, 3, 3
data = np.ones([nB, nC, nH, nW], dtype=np.float16)

weight = np.ones([nCOut, nC, nKernelHeight, nKernelWidth], dtype=np.float32)
bias = np.ascontiguousarray(np.zeros(nCOut, dtype=np.float32))

np.set_printoptions(precision=3, linewidth=200, suppress=True)
cudart.cudaDeviceSynchronize()

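# Create the builder, network, and builder config; enable FP16 mode and declare an FP16 input tensor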
logger = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)
inputT0 = network.add_input("inputT0", trt.float16, (nB, nC, nH, nW))
#------------------------------------------------------------------------------- Network
convolutionLayer = network.add_convolution_nd(inputT0, nCOut, (nKernelHeight, nKernelWidth), trt.Weights(weight), trt.Weights(bias))
convolutionLayer.get_output(0).dtype = trt.float16
#------------------------------------------------------------------------------- Network
network.mark_output(convolutionLayer.get_output(0))
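# Build the serialized network, deserialize it into an engine, and collect its I/O tensor names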
engineString = builder.build_serialized_network(network, config)
engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
nIO = engine.num_io_tensors
lTensorName = [engine.get_tensor_name(i) for i in range(nIO)]
nInput = [engine.get_tensor_mode(lTensorName[i]) for i in range(nIO)].count(trt.TensorIOMode.INPUT)

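# Create the execution context and print the data type and shape of every I/O tensor (the output quoted above)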
context = engine.create_execution_context()
for i in range(nIO):
    print("[%2d]%s->" % (i, "Input " if i < nInput else "Output"), engine.get_tensor_dtype(lTensorName[i]), engine.get_tensor_shape(lTensorName[i]), context.get_tensor_shape(lTensorName[i]), lTensorName[i])

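# Prepare host buffers (the FP16 input data plus empty arrays for the outputs), then matching device buffers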
bufferH = []
bufferH.append(np.ascontiguousarray(data))
for i in range(nInput, nIO):
    bufferH.append(np.empty(context.get_tensor_shape(lTensorName[i]), dtype=trt.nptype(engine.get_tensor_dtype(lTensorName[i]))))
bufferD = []
for i in range(nIO):
    bufferD.append(cudart.cudaMalloc(bufferH[i].nbytes)[1])

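# Copy the input data to the device and bind each tensor name to its device address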
for i in range(nInput):
    cudart.cudaMemcpy(bufferD[i], bufferH[i].ctypes.data, bufferH[i].nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)

for i in range(nIO):
    context.set_tensor_address(lTensorName[i], int(bufferD[i]))

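# Run inference on the default CUDA stream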
context.execute_async_v3(0)

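# Copy the outputs back to the host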
for i in range(nInput, nIO):
    cudart.cudaMemcpy(bufferH[i].ctypes.data, bufferD[i], bufferH[i].nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost)

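# Print the contents of each buffer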
for i in range(nIO):
    print(lTensorName[i])
    print(bufferH[i])

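# Free the device buffers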
for b in bufferD:
    cudart.cudaFree(b)