styler00dollar / VSGAN-tensorrt-docker

Using VapourSynth with super resolution and interpolation models and speeding them up with TensorRT.
BSD 3-Clause "New" or "Revised" License

How to convert rife's onnx model? #56

Closed · jhl13 closed 10 months ago

jhl13 commented 10 months ago

I found the converted RIFE ONNX models at https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases, but their input is slightly different from the input of Practical-RIFE; it seems the timestep is integrated into the input. Can you share the code for the RIFE ONNX conversion? I didn't find it in the project.

styler00dollar commented 10 months ago

I write a new export script every time the architecture changes, which is nearly every RIFE version. I make some adjustments so I can use it properly with core.trt from mlrt, since that only allows a single input: the first 6 channels are the two input images, then one channel for the timestep and one for the scale. You can see it in rife_trt.py.
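
As a rough illustration (a minimal sketch of the packing idea, not the actual code from rife_trt.py; shapes and values are just examples):

import torch

img0 = torch.rand(1, 3, 256, 256)  # channels 0-2: first frame
img1 = torch.rand(1, 3, 256, 256)  # channels 3-5: second frame
timestep = torch.full((1, 1, 256, 256), 0.5)  # channel 6: interpolation position
scale = torch.full((1, 1, 256, 256), 1.0)     # channel 7: flow scale

packed = torch.cat([img0, img1, timestep, scale], dim=1)  # (1, 8, 256, 256)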

The file will change depending on the version, but with 4.12 as an example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        ...
    def forward(self, x, flow=None, scale=1):
        # scale arrives as a 0-dim tensor sliced from the packed input,
        # so unwrap it into a plain Python number when possible
        try:
            scale = scale.item()
        except Exception:
            pass
        ...

class IFNet(nn.Module):
    def __init__(self):
        ...

    def forward(self, input, fastmode=True, ensemble=False):
        input = torch.clamp(input, 0, 1)
        # channels 0-5: the two input frames
        img0 = input[:, :3]
        img1 = input[:, 3:6]
        # channels 6 and 7 are constant planes; read one scalar from each
        timestep = input[:, 6:7][0][0][0][0]
        scale = input[:, 7:8][0][0][0][0]
        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        # pad height and width up to multiples of 64 for the flow pyramid
        n, c, h, w = img0.shape
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)
        x = torch.cat((img0, img1), 1)
        # broadcast the scalar timestep back into a full-resolution plane
        timestep = (x[:, :1].clone() * 0 + 1) * timestep
        timestep = timestep.float()
        ...
        # crop the padding off the final merged frame
        return merged[3][:, :, :h, :w]

def convert(param, rank=-1):
    # strip the "module." prefix left behind by DataParallel checkpoints
    if rank == -1:
        return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
    else:
        return param

model = IFNet()
model.eval()
state_dict = convert(torch.load("flownet.pkl", map_location="cpu"))

model.load_state_dict(state_dict, strict=False)
# torch.save(model.state_dict(), "resaved_rife.pth")

with torch.inference_mode():
    # NCHW: dim 2 is height, dim 3 is width
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    }

    torch.onnx.export(
        model.cuda(),
        # dummy input: 6 image channels plus two planes of ones (timestep, scale)
        torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda(),
        "rife412_fastTrue_ensembleFalse_op18_clamp.onnx",
        verbose=False,
        opset_version=18,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
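
To sanity-check the exported file, something like the following should work (a hedged sketch using onnxruntime; the file name matches the export call above, everything else is example values):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "rife412_fastTrue_ensembleFalse_op18_clamp.onnx",
    providers=["CPUExecutionProvider"],
)
# 8 channels: two RGB frames, then constant timestep and scale planes (NCHW)
img0 = np.random.rand(1, 3, 256, 256).astype(np.float32)
img1 = np.random.rand(1, 3, 256, 256).astype(np.float32)
timestep = np.full((1, 1, 256, 256), 0.5, dtype=np.float32)
scale = np.ones((1, 1, 256, 256), dtype=np.float32)
inp = np.concatenate([img0, img1, timestep, scale], axis=1)
out = sess.run(["output"], {"input": inp})[0]  # interpolated frame, (1, 3, 256, 256)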

For fp16 you will need further adjustments:


class IFBlock(nn.Module):
    def forward(self, x, flow=None, scale=1):
        ...
        # cast activations to half so they match the fp16 weights
        feat = self.conv0(x.half())
        ...

class IFNet(nn.Module):
    def forward(self, input, fastmode=True, ensemble=False):
        ...
        timestep = timestep.half()
        ...
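
Presumably the export call is then adjusted along these lines (my assumption, not the author's confirmed script; with the internal .half() casts above, the dummy input itself can stay fp32, and dynamic_axes is reused from the export above):

# assumed fp16 export variant; the output file name is hypothetical
model = model.cuda().half()

with torch.inference_mode():
    torch.onnx.export(
        model,
        torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda(),
        "rife412_fastTrue_ensembleFalse_op18_fp16_clamp.onnx",
        opset_version=18,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
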
jhl13 commented 10 months ago

Thank you, I now roughly understand the ONNX conversion process.

Djdefrag commented 8 months ago

@styler00dollar Hi my friend,

Sorry to bother you, but can you share the full modified RIFE code, please?

And another question: I see this dummy input for the ONNX export

torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda()

and when I have to use the ONNX model for inference, what are the inputs? For example:

styler00dollar commented 8 months ago

https://github.com/styler00dollar/VSGAN-tensorrt-docker/issues/56#issuecomment-1893040276

Djdefrag commented 8 months ago

Hi everyone, sorry to bother again.

Once I convert RIFE to ONNX, I then have to create the input via numpy for inference:

# aliased numpy imports matching the names used below
from numpy import ndarray as numpy_ndarray
from numpy import concatenate as numpy_concatenate
from numpy import ones as numpy_ones

def concatenate_frames(
        frame_1: numpy_ndarray,
        frame_2: numpy_ndarray,
    ) -> numpy_ndarray:

    # get_image_resolution is my own helper returning (height, width)
    height, width = get_image_resolution(frame_1)

    input_images = numpy_concatenate((frame_1, frame_2), axis=2)
    timestep = numpy_ones((height, width, 1))
    scale    = numpy_ones((height, width, 1))
    result = numpy_concatenate((input_images, timestep, scale), axis=2)
    return result

Is that correct?

styler00dollar commented 7 months ago

The input shape with the ONNX code I showed is (1, 8, height, width), where 1 can be a different batch size.
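
In other words, the HWC concatenation above needs to become a channel-axis concatenation in NCHW layout; a hedged numpy sketch (layout assumptions are mine, not code from the repo):

import numpy as np

def concatenate_frames(frame_1: np.ndarray, frame_2: np.ndarray) -> np.ndarray:
    # frames assumed HWC float32 in [0, 1]; convert to NCHW and pack
    img0 = np.transpose(frame_1, (2, 0, 1))[None]  # (1, 3, H, W)
    img1 = np.transpose(frame_2, (2, 0, 1))[None]  # (1, 3, H, W)
    _, _, height, width = img0.shape
    timestep = np.full((1, 1, height, width), 0.5, dtype=np.float32)
    scale = np.ones((1, 1, height, width), dtype=np.float32)
    return np.concatenate((img0, img1, timestep, scale), axis=1)  # (1, 8, H, W)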

lgthappy commented 2 months ago

Hi everyone, does anyone have working source code for the ONNX conversion? I've been trying to follow the steps for ages and never got it right.