grimoire / torch2trt_dynamic

A PyTorch to TensorRT converter with dynamic shape support
MIT License

fix interpolate dynamic input #3

Open huliang2016 opened 3 years ago

huliang2016 commented 3 years ago


grimoire commented 3 years ago

Hi
Thanks for the bug report. Dynamic shape support for interpolate is a little bit ... complex. I wrap Tensor.shape[i] in an IntWarper in size.py to trace the shape data. For interpolate, if size contains any IntWarper, the is_shape_tensor flag is set. That means if your model is:

        xh, xw = x.shape[-2] * 2, x.shape[-1] * 2
        x = F.interpolate(x, size=(xh, xw), mode="nearest")

The converter should give you the right answer.
But you have cast it back to a Python int, so the IntWarper can no longer trace the shape. A constant value will be used as the interpolate size instead.
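To make the effect of the cast concrete, here is a minimal sketch. `TracedInt` and the `is_shape_tensor` helper below are hypothetical stand-ins written for this illustration; they are not the actual IntWarper class or the converter code in size.py. The sketch only assumes that tracking relies on an int subclass that survives arithmetic.

```python
class TracedInt(int):
    """Hypothetical shape-tracking int (illustration only, not IntWarper)."""

    def __new__(cls, value, expr):
        obj = super().__new__(cls, value)
        obj.expr = expr  # symbolic description of how the value was derived
        return obj

    def _combine(self, other, op, symbol):
        other_expr = getattr(other, "expr", int(other))
        return TracedInt(op(int(self), int(other)),
                         f"({self.expr} {symbol} {other_expr})")

    def __mul__(self, other):
        return self._combine(other, lambda a, b: a * b, "*")

    def __add__(self, other):
        return self._combine(other, lambda a, b: a + b, "+")

    def __sub__(self, other):
        return self._combine(other, lambda a, b: a - b, "-")


def is_shape_tensor(size):
    """True if any entry of size still carries tracing information."""
    return any(isinstance(s, TracedInt) for s in size)


h = TracedInt(32, "H")  # stands in for x.shape[-2] at trace time

traced = (h * 2, h * 2)           # arithmetic keeps the subclass
cast = (int(h * 2), int(h * 2))   # int() returns a plain Python int

print(is_shape_tensor(traced))  # True
print(is_shape_tensor(cast))    # False
```

After `int(...)`, the value 64 is indistinguishable from a literal 64, so a converter that checks the type of each size entry has nothing left to trace.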
Using scales instead of shape might not be a good idea, because the shape could be:

        xh, xw = x.shape[-2] + 10, x.shape[-1] - 10
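A short worked example of why a fixed scale cannot stand in for an additive size. The helper names and the input sizes 32 and 64 are assumptions chosen for illustration. The output size is computed as floor(input * scale), which is how F.interpolate derives the size from scale_factor.

```python
import math


def target_height(h):
    """The size the model actually asks for: input height plus 10."""
    return h + 10


def scaled_height(h, scale):
    """Size obtained from a fixed scale factor: floor(h * scale)."""
    return math.floor(h * scale)


trace_h = 32                               # height seen during conversion
scale = target_height(trace_h) / trace_h   # 42 / 32 = 1.3125

# The scale reproduces the target at the traced height...
print(scaled_height(32, scale), target_height(32))  # 42 42
# ...but not at any other height.
print(scaled_height(64, scale), target_height(64))  # 84 74
```

A scale captured at conversion time is only correct when the size is a pure multiple of the input shape, which is why the size has to stay traceable.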

I am trying to do something about the int cast. If you have any ideas, please share them with me.