styler00dollar / VSGAN-tensorrt-docker

Using VapourSynth with super resolution and interpolation models and speeding them up with TensorRT.
BSD 3-Clause "New" or "Revised" License
286 stars, 30 forks

How to convert rife's onnx model? #56

Closed. jhl13 closed this issue 10 months ago.

jhl13 commented 10 months ago

I found the converted rife onnx model at, but the onnx model input here is slightly different from the input of Practical-RIFE. It seems that the timestep is integrated into the input. Can you share the code for rife onnx model conversion? I didn't find this code in the project.

styler00dollar commented 10 months ago

I write a new export script every time the architecture changes, which is nearly every RIFE version. I make some adjustments so I can use it properly with core.trt from mlrt, since that only allows one input. The first 6 channels are the 2 input images, followed by one channel for the timestep and one channel for the scale. You can see it in
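A small NumPy sketch of that single-input layout (not from the repository; `pack_rife_input` and `unpack_rife_input` are hypothetical helpers, and both frames are assumed to be float32 CHW arrays in [0, 1]). Packing fills one constant plane each for timestep and scale; unpacking slices the channels the same way the exported `forward` does:

```python
import numpy as np


def pack_rife_input(img0, img1, timestep, scale):
    """Hypothetical helper: pack two CHW frames plus two scalars into (1, 8, H, W)."""
    _, h, w = img0.shape
    # timestep and scale are stored as constant planes, one channel each
    t_plane = np.full((1, h, w), timestep, dtype=np.float32)
    s_plane = np.full((1, h, w), scale, dtype=np.float32)
    packed = np.concatenate([img0, img1, t_plane, s_plane], axis=0)
    return packed[np.newaxis].astype(np.float32)  # add the batch dimension


def unpack_rife_input(packed):
    """Hypothetical mirror of the channel slicing in the exported IFNet.forward."""
    img0 = packed[:, :3]
    img1 = packed[:, 3:6]
    # the planes are constant, so a single pixel recovers each scalar
    timestep = packed[:, 6:7][0][0][0][0]
    scale = packed[:, 7:8][0][0][0][0]
    return img0, img1, float(timestep), float(scale)


a = np.random.rand(3, 4, 6).astype(np.float32)
b = np.random.rand(3, 4, 6).astype(np.float32)
packed = pack_rife_input(a, b, timestep=0.5, scale=1.0)
print(packed.shape)  # (1, 8, 4, 6)
_, _, t, s = unpack_rife_input(packed)
print(t, s)  # 0.5 1.0
```

Since the exported `forward` clamps the whole packed tensor to [0, 1] before slicing it, the timestep and scale channels are clamped too, so a scale above 1 would silently become 1.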

The file will change depending on the version, but with 4.12 as an example:

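import torch
import torch.nn as nn
import torch.nn.functional as F

class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        ...  # layers unchanged (elided here)

    def forward(self, x, flow=None, scale=1):
        ...  # surrounding code elided
        # scale now arrives as a 0-dim tensor derived from the packed input,
        # so convert it back to a Python number
        scale = scale.item()
        ...

class IFNet(nn.Module):
    def __init__(self):
        ...  # unchanged (elided here)

    def forward(self, input, fastmode=True, ensemble=False):
        # single packed input: 3 + 3 image channels, 1 timestep, 1 scale
        input = torch.clamp(input, 0, 1)
        img0 = input[:, :3]
        img1 = input[:, 3:6]
        # the timestep and scale planes are constant, so one pixel is enough
        timestep = input[:, 6:7][0][0][0][0]
        scale = input[:, 7:8][0][0][0][0]
        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        # pad H and W up to the next multiple of 64
        n, c, h, w = img0.shape
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)
        x = torch.cat((img0, img1), 1)
        # broadcast the scalar timestep to a 1-channel map of the padded size
        timestep = (x[:, :1].clone() * 0 + 1) * timestep
        timestep = timestep.float()
        ...  # flow estimation and warping (elided) produce `merged`
        # crop the padding away again
        return merged[3][:, :, :h, :w]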

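def convert(param, rank=-1):
    # DistributedDataParallel checkpoints prefix every key with "module."
    if rank == -1:
        return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
    else:
        return param

model = IFNet()
state_dict = convert(torch.load("flownet.pkl", map_location="cpu"))

model.load_state_dict(state_dict, strict=False)
# torch.save(..., "resaved_rife.pth")

with torch.inference_mode():
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    }
    torch.onnx.export(
        model.cuda(),
        # dummy input: 6 random image channels + timestep and scale planes of ones
        torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda(),
        "rife.onnx",  # output path: placeholder, not given in the original
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        # further export options (e.g. opset_version) were not shown
    )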
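The shape bookkeeping and the checkpoint handling in the script above can be checked without a GPU or the RIFE weights. This NumPy sketch uses hypothetical helper names: `pad_to_multiple` and `crop_back` mirror the pad-to-64 step and the `[:, :, :h, :w]` crop in `IFNet.forward` (NumPy zero padding standing in for `F.pad`), `scale_list_for` mirrors `scale_list`, and `strip_module_prefix` repeats the logic of `convert()`. It illustrates those steps and is not code from the repository.

```python
import numpy as np


def pad_to_multiple(img, multiple=64):
    """Zero-pad an NCHW array at the bottom/right so H and W are multiples of `multiple`."""
    n, c, h, w = img.shape
    ph = ((h - 1) // multiple + 1) * multiple
    pw = ((w - 1) // multiple + 1) * multiple
    # same amounts as F.pad(img, (0, pw - w, 0, ph - h)) on the last two axes
    padded = np.pad(img, ((0, 0), (0, 0), (0, ph - h), (0, pw - w)))
    return padded, (h, w)


def crop_back(out, size):
    """Undo the padding, like merged[3][:, :, :h, :w]."""
    h, w = size
    return out[:, :, :h, :w]


def scale_list_for(scale):
    """The four per-block scales the forward pass derives from the scale channel."""
    return [8 / scale, 4 / scale, 2 / scale, 1 / scale]


def strip_module_prefix(param, rank=-1):
    """Same logic as convert(): keep only 'module.'-prefixed keys, minus the prefix."""
    if rank == -1:
        return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
    else:
        return param


frame = np.zeros((1, 3, 1080, 1920), dtype=np.float32)
padded, size = pad_to_multiple(frame)
print(padded.shape)                   # (1, 3, 1088, 1920)
print(crop_back(padded, size).shape)  # (1, 3, 1080, 1920)
print(scale_list_for(1.0), scale_list_for(0.5))  # [8.0, 4.0, 2.0, 1.0] [16.0, 8.0, 4.0, 2.0]
print(strip_module_prefix({"module.block0.conv0.0.weight": 0, "foo": 1}))  # {'block0.conv0.0.weight': 0}
```

One consequence worth knowing: if a checkpoint's keys carry no `module.` prefix, `convert()` returns an empty dict, and `load_state_dict(..., strict=False)` will not raise, leaving the model untrained before export. `load_state_dict` returns `missing_keys` and `unexpected_keys`, which is one way to check this.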

For fp16 you will need further adjustments:

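class IFBlock(nn.Module):
    def forward(self, x, flow=None, scale=1):
        ...  # elided
        feat = self.conv0(x.half())  # cast the block input to fp16
        ...

class IFNet(nn.Module):
    def forward(self, input, fastmode=True, ensemble=False):
        ...  # elided
        timestep = timestep.half()  # fp16 counterpart of timestep.float()
        ...
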
jhl13 commented 10 months ago

Thank you, I roughly understand the onnx conversion process.

Djdefrag commented 8 months ago

@styler00dollar Hi my friend,

Sorry to bother you, but can you share the full modified RIFE code, please?

Another question: I see this dummy input for the ONNX export: torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda()

When I use the ONNX model for inference, what should the input be? For example:

styler00dollar commented 8 months ago

Djdefrag commented 8 months ago

Hi everyone, sorry to bother again.

Once I convert RIFE to ONNX, I have to build the input with NumPy for inference:

from numpy import concatenate as numpy_concatenate
from numpy import ndarray as numpy_ndarray
from numpy import ones as numpy_ones

def concatenate_frames(
        frame_1: numpy_ndarray,
        frame_2: numpy_ndarray,
    ) -> numpy_ndarray:

    height, width = frame_1.shape[:2]  # frames are height x width x channels

    input_images = numpy_concatenate((frame_1, frame_2), axis=2)
    timestep = numpy_ones((height, width, 1))
    scale    = numpy_ones((height, width, 1))
    result = numpy_concatenate((input_images, timestep, scale), axis=2)
    return result

Is that correct?

styler00dollar commented 7 months ago

With the ONNX code I showed, the input shape is (batch, 8, height, width), where batch is usually 1.
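Applied to the NumPy code in the previous comment, that means the frames have to go from height × width × channels to channels-first and gain a batch axis. A minimal sketch, assuming RGB `uint8` frames of equal size and a float32 model input in [0, 1] (the exported `forward` clamps to that range); `frames_to_rife_input` and `rife_output_to_frame` are hypothetical names, and whether the pipeline needs RGB or BGR depends on how the frames were decoded:

```python
import numpy as np


def frames_to_rife_input(frame_1, frame_2, timestep=0.5, scale=1.0):
    """Hypothetical helper: build a (1, 8, H, W) float32 array from two HWC uint8 frames.

    Channel order: frame_1 RGB, frame_2 RGB, timestep plane, scale plane.
    """
    height, width = frame_1.shape[:2]
    # HWC uint8 -> CHW float32 in [0, 1]
    img0 = frame_1.transpose(2, 0, 1).astype(np.float32) / 255.0
    img1 = frame_2.transpose(2, 0, 1).astype(np.float32) / 255.0
    t_plane = np.full((1, height, width), timestep, dtype=np.float32)
    s_plane = np.full((1, height, width), scale, dtype=np.float32)
    stacked = np.concatenate([img0, img1, t_plane, s_plane], axis=0)
    return stacked[np.newaxis]  # (1, 8, H, W)


def rife_output_to_frame(output):
    """Hypothetical helper: convert a (1, 3, H, W) float output back to an HWC uint8 frame."""
    chw = np.clip(output[0], 0.0, 1.0)
    return (chw.transpose(1, 2, 0) * 255.0).round().astype(np.uint8)


f1 = np.zeros((720, 1280, 3), dtype=np.uint8)
f2 = np.full((720, 1280, 3), 255, dtype=np.uint8)
x = frames_to_rife_input(f1, f2)
print(x.shape, x.dtype)  # (1, 8, 720, 1280) float32
# With ONNX Runtime this array would be fed as session.run(None, {"input": x}),
# assuming the model was exported with input_names=["input"].
```

The earlier snippet fills the timestep channel with ones; in RIFE a timestep of 0.0 corresponds to the first frame and 1.0 to the second, so 0.5 is the usual value for a single in-between frame.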

lgthappy commented 2 months ago

Hi everyone, do you have the source code for onnx conversion that worked? I've been trying to follow the steps for ages and never got it right.