styler00dollar / VSGAN-tensorrt-docker

Using VapourSynth with super resolution and interpolation models and speeding them up with TensorRT.
BSD 3-Clause "New" or "Revised" License

Is there a way to use the RIFE TensorRT engine without VapourSynth? #75

Closed. yuvraj108c closed this issue 3 weeks ago

yuvraj108c commented 3 weeks ago

Referencing the PyTorch implementation in ComfyUI (https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/tree/main/vfi_models/rife):

is there a way to build TensorRT engines from your models that are compatible with ComfyUI?

I've noticed the ONNX input is tensor: float16[batch_size,8,width,height], and I'm not sure how to convert my inputs to this format in the context of ComfyUI.

Any help will be appreciated. Thanks.

styler00dollar commented 3 weeks ago

RIFE v1 ONNX models take 8 input channels and RIFE v2 models take 7. To use the ONNX models, concatenate your inputs along the channel dimension as img1 (3 channels) + img2 (3 channels) + timestep (1 channel, one number) + [scale (1 channel, one number)]. The v2 models have scale 0.5 hardcoded, so they omit the scale channel.

One example with torch would be:

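```python
import torch

h, w = 128, 128
img0 = torch.randn(1, 3, h, w)      # first frame, NCHW
img1 = torch.randn(1, 3, h, w)      # second frame
timestep = torch.randn(1, 1, h, w)  # random values: only useful for testing shapes
x = torch.cat((img0, img1, timestep), dim=1)  # (1, 7, h, w), the v2 layout
```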
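The same packing can be generalized to both layouts described above. This is a minimal numpy sketch, not code from the repo: `pack_rife_input` is a hypothetical helper, it assumes NCHW frames, casts to float16 because the ONNX input is listed as float16, and appends the scale channel last for v1 models, following the channel order given above.

```python
import numpy as np

def pack_rife_input(img0, img1, timestep, scale=None, dtype=np.float16):
    """Hypothetical helper: concatenate two frames plus constant channels (NCHW).

    img0, img1: arrays of shape (N, 3, H, W).
    scale=None -> 7-channel v2 layout; a number -> 8-channel v1 layout
    (assumed order: img0, img1, timestep, scale).
    """
    if img0.shape != img1.shape or img0.ndim != 4 or img0.shape[1] != 3:
        raise ValueError("expected two (N, 3, H, W) arrays of equal shape")
    n, _, h, w = img0.shape
    # Timestep is a single number broadcast over a full H x W channel.
    channels = [img0, img1, np.full((n, 1, h, w), timestep)]
    if scale is not None:
        channels.append(np.full((n, 1, h, w), scale))
    return np.concatenate(channels, axis=1).astype(dtype)

h, w = 128, 128
img0 = np.random.rand(1, 3, h, w)
img1 = np.random.rand(1, 3, h, w)
x_v2 = pack_rife_input(img0, img1, 0.5)             # v2: 7 channels
x_v1 = pack_rife_input(img0, img1, 0.5, scale=1.0)  # v1: 8 channels
print(x_v2.shape, x_v1.shape)
```

Casting once at the end keeps the concatenation in float64 and avoids mixing dtypes; drop the cast if your engine was built for float32 inputs.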

For inference, you can use ONNX Runtime with the TensorRT execution provider. You could also skip ONNX and use Torch-TensorRT directly.
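A sketch of the ONNX Runtime route, under assumptions: the model has a single input and a single output, the option values and the `./trt_cache` directory are placeholders, and `make_session`/`run_rife` are hypothetical helpers rather than anything from this repo. `onnxruntime` is imported inside `make_session` so the rest of the code can be read or tested without it installed.

```python
import numpy as np

# Providers in decreasing precedence: nodes the TensorRT EP cannot take are
# assigned to CUDA, then CPU. Option values here are illustrative.
PROVIDERS = [
    ("TensorrtExecutionProvider", {
        "trt_fp16_enable": True,
        "trt_engine_cache_enable": True,   # reuse built engines across runs
        "trt_engine_cache_path": "./trt_cache",
    }),
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]

def make_session(model_path):
    """Create an InferenceSession for a RIFE .onnx file (path is yours to supply)."""
    import onnxruntime as ort  # lazy import: only needed when actually running
    return ort.InferenceSession(model_path, providers=PROVIDERS)

def run_rife(session, x):
    """Run one packed (N, 7|8, H, W) array; assumes exactly one model output."""
    input_name = session.get_inputs()[0].name
    (output,) = session.run(None, {input_name: x})
    return output  # expected shape: (N, 3, H, W)

# Usage (requires onnxruntime-gpu with TensorRT available):
# session = make_session("path/to/rife_model.onnx")
# frame = run_rife(session, packed_float16_array)
```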

yuvraj108c commented 3 weeks ago

OK, thanks for replying. How do I get the interpolated image (3 channels) from the engine output?

styler00dollar commented 3 weeks ago

The input has 7 or 8 channels (v2 or v1), and the output has 3 channels: the output is the interpolated image.
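Turning that 3-channel output back into images is a layout change. The sketch below assumes the output is NCHW with values in [0, 1] in the same channel order as the input frames, and that ComfyUI's IMAGE type is a (batch, height, width, channels) float tensor in [0, 1] (wrap the result with `torch.from_numpy` for that). Both helper names are hypothetical.

```python
import numpy as np

def rife_output_to_images(out):
    """(N, 3, H, W) model output -> (N, H, W, 3) float32 in [0, 1].

    Clipping guards against small overshoots from models without a clamp.
    """
    if out.ndim != 4 or out.shape[1] != 3:
        raise ValueError(f"expected (N, 3, H, W), got {out.shape}")
    return np.clip(out.astype(np.float32), 0.0, 1.0).transpose(0, 2, 3, 1)

def to_uint8(images):
    """(N, H, W, 3) float in [0, 1] -> uint8, e.g. for saving a preview."""
    return np.round(images * 255.0).astype(np.uint8)

out = np.random.rand(1, 3, 128, 128).astype(np.float16)  # stand-in for engine output
images = rife_output_to_images(out)
print(images.shape, to_uint8(images).dtype)
```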

yuvraj108c commented 1 week ago

@styler00dollar I've tried running inference with the models, but they seem to accept only inputs of shape (1,7,1,1) and produce outputs of shape (1,3,1,1); anything else fails.

Models tested: rife46_v2_ensembleTrue_op16_mlrt_sim, rife415_v2_ensembleFalse_op20_clamp

```python
import torch

h, w = 1, 1
img0 = torch.randn(1, 3, h, w)
img1 = torch.randn(1, 3, h, w)
timestep = torch.randn(1, 1, h, w)
x = torch.cat((img0, img1, timestep), dim=1)
```

Let me know if I've used the wrong models. Error message:

```
[E] IExecutionContext::setInputShape: Error Code 3: API Usage Error (Parameter check failed, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape for input. Set dimensions are [1,7,128,128]. Expected dimensions are [1,7,1,1].)

    self.tensors[name].copy_(buf)
RuntimeError: output with shape [1, 7, 1, 1] doesn't match the broadcast shape [1, 7, 128, 128]
```

styler00dollar commented 1 week ago

Are you using TensorRT? It sounds like you built a static engine with 1x1-pixel dimensions. Try CUDA or different engine sizes.
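One way to build a dynamic-shape engine is `trtexec` with an optimization profile. This sketch only assembles the command: the input name `input` is taken from the error message above, the file names and shape ranges are placeholders, and it is not the build script this repo uses. `fits_profile` reproduces the check that failed: with a static 1x1 engine, min and max are both (1, 7, 1, 1), so a 128x128 input is rejected.

```python
def trtexec_dynamic_cmd(onnx_path, engine_path, channels=7,
                        min_hw=(64, 64), opt_hw=(720, 1280), max_hw=(1080, 1920),
                        max_batch=1, input_name="input", fp16=True):
    """Build a trtexec argument list for an engine with a dynamic H/W range."""
    def shape(batch, hw):
        # trtexec shape syntax: name:NxCxHxW
        return f"{input_name}:{batch}x{channels}x{hw[0]}x{hw[1]}"
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        f"--saveEngine={engine_path}",
        f"--minShapes={shape(1, min_hw)}",
        f"--optShapes={shape(1, opt_hw)}",
        f"--maxShapes={shape(max_batch, max_hw)}",
    ]
    if fp16:
        cmd.append("--fp16")
    return cmd

def fits_profile(shape, min_shape, max_shape):
    """True if every dimension lies within the profile's [min, max] range."""
    return all(lo <= d <= hi for d, lo, hi in zip(shape, min_shape, max_shape))

cmd = trtexec_dynamic_cmd("rife.onnx", "rife.engine")
print(" ".join(cmd))
# subprocess.run(cmd, check=True)  # needs trtexec on PATH
print(fits_profile((1, 7, 128, 128), (1, 7, 1, 1), (1, 7, 1, 1)))  # False
```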

yuvraj108c commented 1 week ago

Yep, I wasn't building the engine with dynamic input/output shapes. It works now.

Another question: the timestep is a float value, e.g. 0.5.

I've used timestep = torch.randn(1, 1, h, w) for testing, as you said

Now how do I convert the actual timestep (0.5) to the above expected format?

Is this correct?

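```python
import torch

def float_to_tensor(value, height, width, dtype=torch.float32):
    # One channel filled with the same scalar at every pixel: shape (1, 1, H, W).
    # Pass dtype=torch.float16 to match a float16 engine input.
    return torch.full((1, 1, height, width), value, dtype=dtype)
```
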
styler00dollar commented 1 week ago

That is a tensor with the same height and width as the image; it has a single channel with the same number everywhere. The timestep depends on where you want RIFE's output to be: 0.5 means exactly halfway between frame 1 and frame 2. Internally it is just this:

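```python
def forward(self, input):
    img0 = input[:, :3]   # channels 0-2: first frame
    img1 = input[:, 3:6]  # channels 3-5: second frame
    # [0][0][0][0] picks one scalar: batch element 0, pixel (0, 0)
    timestep = input[:, 6:7][0][0][0][0]
    scale = input[:, 7:8][0][0][0][0]  # v1 only; v2 models hardcode scale 0.5
```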

Your code seems to do it correctly.
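The slicing in `forward` can be mirrored in numpy to check a packed input before sending it to an engine. This is an illustrative re-implementation, not the model's code; `unpack_rife_input` is a hypothetical name. It makes one detail visible: because of the `[0][0][0][0]` indexing, the timestep (and scale) is a single value read from batch element 0.

```python
import numpy as np

def unpack_rife_input(x):
    """Mirror of the slicing in forward() for a packed NCHW array."""
    img0 = x[:, :3]
    img1 = x[:, 3:6]
    # Same as input[:, 6:7][0][0][0][0]: one scalar from batch element 0, pixel (0, 0).
    timestep = float(x[0, 6, 0, 0])
    scale = float(x[0, 7, 0, 0]) if x.shape[1] > 7 else None
    return img0, img1, timestep, scale

h, w = 4, 4
frames = np.random.rand(2, 6, h, w)              # two frame pairs in one batch
ts = np.array([0.25, 0.75]).reshape(2, 1, 1, 1)  # a different timestep per sample
x = np.concatenate([frames, np.broadcast_to(ts, (2, 1, h, w))], axis=1)
_, _, t, s = unpack_rife_input(x)
print(t, s)  # 0.25 None: sample 1's timestep of 0.75 is never read
```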

yuvraj108c commented 1 week ago

ok

Isn't batching supported?

I've built the engine with a dynamic batch size. Inference runs without errors, but the outputs are repeated or very similar frames with no visible interpolation (multiplier >= 2).

Only batch size 1 gives correct outputs.

[Screenshots: interpolated outputs for batch sizes 1, 2, 3, 4 and 5]

styler00dollar commented 1 week ago

I did not test batch sizes above 1.
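One possible explanation, inferred only from the `forward` snippet earlier in this thread and not confirmed by the author: the timestep is read as one scalar from batch element 0, so if a batch mixes samples with different timesteps (as happens with multiplier >= 2), every sample would be interpolated at the first sample's timestep. A workaround under that assumption is to batch only frame pairs that share a timestep. The evenly spaced schedule `i / multiplier` and both helper names are assumptions for illustration.

```python
from itertools import islice

def timesteps_for_multiplier(multiplier):
    """Assumed schedule: evenly spaced intermediate timesteps, e.g. 3 -> [1/3, 2/3]."""
    return [i / multiplier for i in range(1, multiplier)]

def batches_by_timestep(num_pairs, multiplier, batch_size):
    """Yield (timestep, [pair indices]) so that every batch shares one timestep."""
    for t in timesteps_for_multiplier(multiplier):
        pairs = iter(range(num_pairs))
        while chunk := list(islice(pairs, batch_size)):
            yield t, chunk

jobs = list(batches_by_timestep(num_pairs=5, multiplier=3, batch_size=2))
for t, chunk in jobs:
    print(round(t, 3), chunk)
```

Each batch is then packed with a constant timestep channel. If outputs are still wrong with uniform-timestep batches, the cause lies elsewhere.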