pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

❓ [Question] Error when using torch_tensorrt.compile to optimize Mask2Former's multi_scale_deformable_attn layer #3098

Open edition3234 opened 3 months ago

edition3234 commented 3 months ago

❓ Question

I was preparing to export a TRT model for Mask2Former with `optimized_model = torch_tensorrt.compile(model, inputs=imgs, enabled_precisions={torch.half})`, where `model` is a Mask2Former loaded through mmseg. However, I encountered an error at the line `value_l_ = value_list[0].flatten(2).transpose(1, 2).reshape(4 * 8, 32, 16, 16)`. The error message was: `"Failed running call_method reshape((FakeTensor(..., device='cuda:0', size=(1, 256, 256), grad_fn=), 32, 32, 16, 16), {}): shape '[32, 32, 16, 16]' is invalid for input of size 65536"`

The original code was `value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)`. Even after replacing all of these variables with constants, the problem remains: during training this reshape works normally, but the above error occurs when using `torch_tensorrt.compile`.
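The numbers in the error message can be checked with simple arithmetic. This is a quick sketch. It assumes that the leading dimension of the reported `FakeTensor` size `(1, 256, 256)` is the batch dimension, as in the mmcv-style layout. Under that assumption, the hard-coded target shape needs four times as many elements as the traced tensor holds. That is consistent with a batch of 4 baked into `4 * 8` while the traced input has a batch of 1.

```python
from math import prod

# Size of value_list[0].flatten(2).transpose(1, 2) as reported in the error message
actual = (1, 256, 256)
# Target shape hard-coded in the reshape: (bs * num_heads, embed_dims, H_, W_) = (4 * 8, 32, 16, 16)
target = (4 * 8, 32, 16, 16)

print(prod(actual))                  # 65536
print(prod(target))                  # 262144
print(prod(target) // prod(actual))  # 4

# The same layout with bs = 1 instead of 4 fits exactly: (1 * 8, 32, 16, 16)
print(prod((1 * 8, 32, 16, 16)) == prod(actual))  # True
```

The issue does not say which error, if any, the original variable-based line produced under compilation. The arithmetic above only accounts for the constant version.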

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

Additional context

The complete code is as follows:

    import torch.nn.functional as F

    # value: (bs, num_keys, num_heads, embed_dims); split into the three feature levels
    value_list = value.split([16*16, 32*32, 64*64], dim=1)
    # Level 0 (16x16), with bs * num_heads = 4 * 8 and embed_dims = 32 hard-coded
    value_l_ = value_list[0].flatten(2).transpose(1, 2).reshape(4 * 8, 32, 16, 16)
    sampling_grid_l_ = sampling_grids[:, :, :, 0].transpose(1, 2).flatten(0, 1)
    sampling_value_l_ = F.grid_sample(
        value_l_,
        sampling_grid_l_,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False)
    sampling_value_list.append(sampling_value_l_)

apbose commented 1 month ago

@edition3234 What value are you passing? The error looks like a mismatch between the input size and the shape you are reshaping to. Could you please provide a simple repro example?
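Below is a small eager-mode check of the level-0 fragment from the question, of the kind a simple repro might start from. It is a minimal sketch built on several assumptions:

- `value` uses the mmcv-style layout `(bs, num_keys, num_heads, embed_dims)`.
- `sampling_grids` has shape `(bs, num_queries, num_heads, num_levels, num_points, 2)`.
- The helper name `sample_level0` is hypothetical, and so are the dummy sizes (10 queries, 3 levels, 4 points).

Reading `bs` from the input tensor, instead of hard-coding 4, lets the same reshape work for both batch sizes tried. The sketch runs on CPU and does not call `torch_tensorrt.compile`.

```python
import torch
import torch.nn.functional as F


def sample_level0(value, sampling_grids, num_heads=8, embed_dims=32):
    """Hypothetical helper: grid-sample level 0 (16x16) without a hard-coded batch size."""
    H_, W_ = 16, 16
    bs = value.shape[0]  # batch size taken from the input, not fixed to 4
    value_list = value.split([16 * 16, 32 * 32, 64 * 64], dim=1)
    # (bs, 256, num_heads, embed_dims) -> (bs * num_heads, embed_dims, 16, 16)
    value_l_ = value_list[0].flatten(2).transpose(1, 2).reshape(
        bs * num_heads, embed_dims, H_, W_)
    # (bs, num_queries, num_heads, num_points, 2) -> (bs * num_heads, num_queries, num_points, 2)
    sampling_grid_l_ = sampling_grids[:, :, :, 0].transpose(1, 2).flatten(0, 1)
    return F.grid_sample(
        value_l_,
        sampling_grid_l_,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False)


num_keys = 16 * 16 + 32 * 32 + 64 * 64  # 5376 keys across the three levels
for bs in (1, 4):
    value = torch.randn(bs, num_keys, 8, 32)
    # Dummy normalized grid in [-1, 1] (illustrative sizes only)
    sampling_grids = torch.rand(bs, 10, 8, 3, 4, 2) * 2 - 1
    out = sample_level0(value, sampling_grids)
    print(bs, tuple(out.shape))
# prints "1 (8, 32, 10, 4)" and then "4 (32, 32, 10, 4)"
```

With a fixed compile-time input, the traced `bs` is simply the batch size of `imgs`. This sketch has not been checked against torch_tensorrt itself.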