Closed. peri044 closed this 3 months ago.
Hello, I want to use torch_tensorrt to accelerate the ViT model, but I ran into the error below. It seems the expand operator still has some problems?
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: [SLICE]-[aten_ops.expand.default]-[/vit_embeddings/expand]: ISliceLayer has out of bounds access on axis 0 Condition '<' violated: 3 >= 1.)
Traceback (most recent call last):
  File "/mnt/bn/hukongtao-infer-speed/mlx/users/kongtao.hu/codebase/EasyGuard_0617/speed_vit_test.py", line 59, in <module>
    trt_gm = torch_tensorrt.compile(
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/_compile.py", line 250, in compile
    trt_graph_module = dynamo_compile(
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 243, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 431, in compile_module
    trt_module = convert_module(
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 107, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 88, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 350, in run
    assert serialized_engine
AssertionError
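The error can be read with a little index arithmetic. Along each axis, a TensorRT slice layer reads input element `start + i * stride` for output position `i`, and in its default mode every read must stay inside the input. Assuming the expand converter sliced the batch axis of a size-1 tensor with `stride = 1` to produce the kOPT batch of 4 (one plausible reading of the message; the converter source is not shown here), the last read index is 3, which matches the `3 >= 1` above, whereas a stride of 0 would read index 0 four times. The helpers below are hypothetical illustrations in plain Python, not Torch-TensorRT or TensorRT code.

```python
def last_slice_index(start: int, size: int, stride: int) -> int:
    """Index of the last input element a slice reads along one axis.

    Output position i reads input[start + i * stride], so the last
    position (i = size - 1) reads start + (size - 1) * stride.
    """
    return start + (size - 1) * stride


def in_bounds(input_dim: int, start: int, size: int, stride: int) -> bool:
    """True if every index the slice reads lies in [0, input_dim)."""
    if size == 0:
        return True  # an empty slice reads nothing
    first = start
    last = last_slice_index(start, size, stride)
    return 0 <= min(first, last) and max(first, last) < input_dim


# Hypothetical case: batch axis of size 1, kOPT batch size of 4.
print(last_slice_index(start=0, size=4, stride=1))        # 3, as in "3 >= 1"
print(in_bounds(input_dim=1, start=0, size=4, stride=1))  # False
print(in_bounds(input_dim=1, start=0, size=4, stride=0))  # True
```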
My code:
import time

import torch
import torch_tensorrt
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision.models import ResNet50_Weights, resnet50
from transformers.trainer_utils import set_seed

set_seed(0)

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model = model.eval().cuda()
print(type(model))

# Dynamic batch dimension: one optimization profile covering batch sizes 1 to 8,
# with TensorRT tuning for batch size 4 (opt_shape).
inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[4, 3, 224, 224],
        max_shape=[8, 3, 224, 224],
        dtype=torch.float32,
    )
]

# Compile through the Dynamo frontend; TensorRT may pick FP32 or FP16 kernels.
trt_gm = torch_tensorrt.compile(
    model, "dynamo", inputs, enabled_precisions={torch.float, torch.half}
)
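For reference, `torch.Tensor.expand` keeps a dimension when its size is `-1`, may broadcast only size-1 dimensions, and may prepend new leading dimensions (which cannot be `-1`). In Hugging Face's `ViTEmbeddings`, the class token appears to be broadcast along the batch axis with `cls_token.expand(batch_size, -1, -1)`, which would match the `/vit_embeddings/expand` node in the error; with a dynamic batch, the output's first dimension then follows the input's. The pure-Python sketch below reproduces those shape rules and applies them to an assumed class-token shape of `[1, 1, 768]` (vit-base's hidden size) for the min/opt/max batch sizes used above. `expand_shape` is a hypothetical helper, not a PyTorch API.

```python
from typing import List, Sequence


def expand_shape(input_shape: Sequence[int], sizes: Sequence[int]) -> List[int]:
    """Output shape of tensor.expand(*sizes), following PyTorch's rules.

    - len(sizes) may exceed len(input_shape); extra dims are prepended.
    - -1 keeps the existing size (not allowed for prepended dims).
    - Only dims of size 1 may be broadcast to a different size.
    """
    if len(sizes) < len(input_shape):
        raise ValueError("expand: fewer sizes than input dimensions")
    offset = len(sizes) - len(input_shape)
    out: List[int] = []
    for i, target in enumerate(sizes):
        if i < offset:
            # Newly prepended leading dimension.
            if target == -1:
                raise ValueError("expand: -1 not allowed for a new leading dim")
            out.append(target)
            continue
        current = input_shape[i - offset]
        if target == -1 or target == current:
            out.append(current)
        elif current == 1:
            out.append(target)  # broadcast a singleton dimension
        else:
            raise ValueError(f"expand: cannot expand dim {i} from {current} to {target}")
    return out


# Assumed vit-base class-token shape; batch sizes from min/opt/max_shape.
for batch in (1, 4, 8):
    print(expand_shape([1, 1, 768], [batch, -1, -1]))
# [1, 1, 768]
# [4, 1, 768]
# [8, 1, 768]
```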
Description
This PR fixes bugs in the expand converter when the input shape is dynamic.
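One way to lower expand onto a slice that stays in bounds for every shape in an optimization profile is to give every axis whose input size is 1 a stride of 0, even when the output size on that axis also happens to be 1, and to take the slice size from the (possibly dynamic) output shape. Choosing strides by comparing against a single concrete shape, such as the min shape where batch 1 equals batch 1, would yield stride 1 and then read out of bounds at kOPT. The plain-Python sketch below computes such parameters and checks them against the min/opt/max batch sizes from the report. It illustrates the idea under these assumptions and is not a description of the change this PR makes; `slice_params` and `reads_in_bounds` are hypothetical names.

```python
from typing import List, Sequence, Tuple

Params = Tuple[List[int], List[int], List[int]]  # (start, size, stride)


def slice_params(input_shape: Sequence[int], output_shape: Sequence[int]) -> Params:
    """Slice start/size/stride that broadcast input_shape to output_shape.

    Assumes both shapes have the same rank (any new leading dims were
    already added with size 1). Every size-1 input axis gets stride 0,
    so each output position re-reads input index 0 whatever the batch.
    """
    if len(input_shape) != len(output_shape):
        raise ValueError("rank mismatch")
    start: List[int] = []
    size: List[int] = []
    stride: List[int] = []
    for in_dim, out_dim in zip(input_shape, output_shape):
        if in_dim == 1:
            stride.append(0)  # broadcast axis: safe for any output size
        elif in_dim == out_dim:
            stride.append(1)  # axis copied unchanged
        else:
            raise ValueError(f"cannot expand {in_dim} to {out_dim}")
        start.append(0)
        size.append(out_dim)
    return start, size, stride


def reads_in_bounds(input_shape: Sequence[int], params: Params) -> bool:
    """Check that the last index read on every axis is < the input dim."""
    for in_dim, s, n, st in zip(input_shape, *params):
        if n > 0 and s + (n - 1) * st >= in_dim:
            return False
    return True


cls_token = [1, 1, 768]  # assumed vit-base class-token shape
for batch in (1, 4, 8):  # min / opt / max batch from the profile
    params = slice_params(cls_token, [batch, 1, 768])
    print(batch, params[2], reads_in_bounds(cls_token, params))
# 1 [0, 0, 1] True
# 4 [0, 0, 1] True
# 8 [0, 0, 1] True
```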