NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter
MIT License

Preceding torch.cuda() operations affect batched results #392

Closed PeterVennerstrom closed 4 years ago

PeterVennerstrom commented 4 years ago

In order for batched inference to work on my machine, it looks like a torch CPU tensor must be moved to the GPU via .cuda() directly before inference with model_trt. If a tensor is placed on the GPU and then manipulated as a CUDA tensor in any pre-processing step, only the result for the first image in the batch (e.g. img[0], where img.shape = [2, 3, 64, 64]) is valid.

The model was created with the max_batch_size argument set to 8.

model_trt = torch2trt(model, [x], max_batch_size=8)

Working Example 1:

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('resnet18_trt.pth'))
img = torch.rand(1, 3, 64, 64) / 2
img = img.expand(2, 3, 64, 64).cuda()
r = model_trt(img)

r[0] == r[1]

Working Example 2:

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('resnet18_trt.pth'))
img = torch.rand(1, 3, 64, 64).cuda() / 2
img = img.cpu()
img = img.expand(2, 3, 64, 64).cuda()
r = model_trt(img)

r[0] == r[1]

Non-Working Example:

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('resnet18_trt.pth'))
img = torch.rand(1, 3, 64, 64).cuda() / 2
img = img.expand(2, 3, 64, 64)
r = model_trt(img)

r[0] != r[1]

jaybdub commented 4 years ago

Hi @PeterVennerstrom,

Thanks for reaching out!

My initial guess is that torch.expand only modifies the stride of the tensor and does not copy the underlying data. If this is the case, every batch entry other than the first has no memory of its own; each one points back at the first. When the data is passed to the TRTModule, however, it is expected to be contiguous.
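The stride behaviour described above can be checked on the CPU, without TensorRT. This is a minimal sketch in plain PyTorch; the shapes mirror the issue, and the variable name `batched` is only illustrative.

```python
import torch

# One 3x64x64 image with a leading batch dimension of 1, as in the issue.
img = torch.rand(1, 3, 64, 64)
batched = img.expand(2, 3, 64, 64)

# expand() returns a view: no new memory is allocated and the expanded
# batch dimension gets stride 0, so batched[1] aliases batched[0].
print(img.stride())                          # (12288, 4096, 64, 1)
print(batched.stride())                      # (0, 4096, 64, 1)
print(batched.is_contiguous())               # False
print(batched.data_ptr() == img.data_ptr())  # True: same memory

# The view reports twice as many elements as the memory actually backing it,
# so a consumer that reads numel() values from data_ptr() assuming a dense
# layout would run past the first image into memory that is not part of it.
print(img.numel(), batched.numel())          # 12288 24576
```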

Do you mind trying

img = img.contiguous()

before executing the TRTModule?
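A sketch of what `contiguous()` changes, again on the CPU with plain PyTorch. The last lines also show that copying an expanded tensor with the default `memory_format=torch.preserve_format` (which `clone()`, `.to()` and `.cuda()` use) produces a dense tensor. That would plausibly explain why Working Examples 1 and 2, which call `.expand()` on a CPU tensor before `.cuda()`, gave correct results; this link is an inference, not something confirmed in the thread.

```python
import torch

img = torch.rand(1, 3, 64, 64)
batched = img.expand(2, 3, 64, 64)

# contiguous() materialises the view into a fresh, densely packed buffer.
dense = batched.contiguous()
print(dense.stride())                       # (12288, 4096, 64, 1)
print(dense.is_contiguous())                # True
print(dense.data_ptr() != img.data_ptr())   # True: a new allocation
print(torch.equal(dense[0], dense[1]))      # True: both entries hold real data

# A copy with memory_format=torch.preserve_format cannot keep stride 0
# for an overlapping input, so it falls back to a contiguous layout.
copy = batched.clone()
print(copy.is_contiguous())                 # True
```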

If this resolves the issue, we should probably handle this case by default inside TRTModule.
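One way to handle this by default, sketched as a hypothetical wrapper rather than anything torch2trt provides: `ContiguousInputs` calls `contiguous()` on each tensor argument before forwarding it. `RequiresDense` is a stand-in for an engine that reads densely packed memory; with a real engine you would presumably wrap the loaded module instead, e.g. `model_trt = ContiguousInputs(model_trt)`.

```python
import torch


class ContiguousInputs(torch.nn.Module):
    """Hypothetical wrapper (not part of torch2trt): makes every tensor
    argument contiguous before calling the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *inputs):
        # contiguous() returns the tensor itself when it is already dense,
        # so only views such as expand() results pay for a copy.
        inputs = tuple(
            x.contiguous() if isinstance(x, torch.Tensor) else x for x in inputs
        )
        return self.module(*inputs)


class RequiresDense(torch.nn.Module):
    """Stand-in for an engine that needs a dense input layout."""

    def forward(self, x):
        assert x.is_contiguous(), "input must be contiguous"
        return x.sum(dim=(1, 2, 3))


model = ContiguousInputs(RequiresDense())
img = torch.rand(1, 3, 64, 64).expand(2, 3, 64, 64)
r = model(img)  # succeeds: the wrapper materialised the expanded view
```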

Please let me know if this helps or you run into issues.

Best, John

PeterVennerstrom commented 4 years ago

Hi John,

Your suggestion to run:

 img = img.contiguous()

resolved the issue.

Thanks for the help!