I write a new export script every time the architecture changes, which is nearly every RIFE version. I make some adjustments so I can use it properly with core.trt from mlrt, since that only allows one input. The first 6 channels are the two input images, then there is one channel for the timestep and one channel for the scale. You can see it in rife_trt.py.
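For illustration, a minimal sketch of how such an 8-channel input can be packed (pack_rife_input is a hypothetical helper, not code from rife_trt.py; it assumes img0 and img1 are (N, 3, H, W) float tensors in [0, 1]):

import torch

def pack_rife_input(img0, img1, timestep=0.5, scale=1.0):
    # Channels 0-2: first frame, channels 3-5: second frame,
    # channel 6: constant timestep plane, channel 7: constant scale plane.
    n, _, h, w = img0.shape
    t = torch.full((n, 1, h, w), timestep)
    s = torch.full((n, 1, h, w), scale)
    return torch.cat((img0, img1, t, s), dim=1)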
The file will change depending on the version, but with 4.12 as an example:
import torch
import torch.nn as nn
import torch.nn.functional as F

class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        ...

    def forward(self, x, flow=None, scale=1):
        # scale arrives as a 0-dim tensor during tracing; unwrap it to a float
        try:
            scale = scale.item()
        except AttributeError:
            pass
        ...
class IFNet(nn.Module):
    def __init__(self):
        ...

    def forward(self, input, fastmode=True, ensemble=False):
        input = torch.clamp(input, 0, 1)
        # Unpack the single 8-channel input: two RGB frames plus
        # constant timestep and scale planes.
        img0 = input[:, :3]
        img1 = input[:, 3:6]
        timestep = input[:, 6:7][0][0][0][0]  # read the scalar off the plane
        scale = input[:, 7:8][0][0][0][0]
        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        n, c, h, w = img0.shape
        # Pad both frames so height and width are multiples of 64
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)
        x = torch.cat((img0, img1), 1)
        # Rebuild a full-resolution timestep plane from the scalar
        timestep = (x[:, :1].clone() * 0 + 1) * timestep
        timestep = timestep.float()
        ...
        # Crop the padding off before returning
        return merged[3][:, :, :h, :w]
def convert(param, rank=-1):
    # Strip the "module." prefix left behind by DataParallel checkpoints
    if rank == -1:
        return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
    else:
        return param
model = IFNet()
model.eval()
state_dict = convert(torch.load("flownet.pkl", map_location="cpu"))
model.load_state_dict(state_dict, strict=False)
# torch.save(model.state_dict(), "resaved_rife.pth")

with torch.inference_mode():
    # NCHW layout: dim 2 is height, dim 3 is width
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    }
    torch.onnx.export(
        model.cuda(),
        # Dummy input: 6 image channels plus timestep and scale planes set to 1
        torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda(),
        "rife412_fastTrue_ensembleFalse_op18_clamp.onnx",
        verbose=False,
        opset_version=18,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
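Once exported, the model can be run with e.g. onnxruntime. A minimal sketch (assuming onnxruntime-gpu is installed and the frames are (1, 3, H, W) float32 arrays in [0, 1]; since the padding arithmetic may be baked in at trace time, it is safest to keep H and W multiples of 64):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "rife412_fastTrue_ensembleFalse_op18_clamp.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

def interpolate(frame0, frame1, timestep=0.5, scale=1.0):
    n, _, h, w = frame0.shape
    t = np.full((n, 1, h, w), timestep, dtype=np.float32)
    s = np.full((n, 1, h, w), scale, dtype=np.float32)
    inp = np.concatenate((frame0, frame1, t, s), axis=1)
    return session.run(["output"], {"input": inp})[0]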
For fp16 you will need further adjustments:
class IFBlock(nn.Module):
    def forward(self, x, flow=None, scale=1):
        ...
        # Cast the block input to half precision
        feat = self.conv0(x.half())
        ...

class IFNet(nn.Module):
    def forward(self, input, fastmode=True, ensemble=False):
        ...
        timestep = timestep.half()
        ...
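With those casts in place, the export call itself presumably also moves the model and dummy input to half precision. A sketch of what that might look like; the .half() calls and the fp16 filename are assumptions, not confirmed by this thread:

with torch.inference_mode():
    torch.onnx.export(
        model.cuda().half(),  # assumed: cast weights to fp16 before export
        torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda().half(),
        "rife412_fastTrue_ensembleFalse_op18_fp16_clamp.onnx",  # hypothetical name
        opset_version=18,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,  # same dynamic_axes dict as above
    )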
Thank you, I roughly understand the onnx conversion process.
@styler00dollar Hi my friend,
sorry to bother you, but could you share the full modified RIFE code?
And another question: I see this dummy input for the onnx export
torch.cat([torch.rand(1, 6, 256, 256), torch.ones(1, 2, 256, 256)], 1).cuda()
but when I have to use the onnx model for inference, what are the inputs? For example:
Hi everyone, sorry to bother again.
Once I convert RIFE to onnx, I then have to create the input via numpy for inference:

def concatenate_frames(
    frame_1: numpy_ndarray,
    frame_2: numpy_ndarray,
) -> numpy_ndarray:
    height, width = get_image_resolution(frame_1)
    input_images = numpy_concatenate((frame_1, frame_2), axis=2)
    timestep = numpy_ones((height, width, 1))
    scale = numpy_ones((height, width, 1))
    result = numpy_concatenate((input_images, timestep, scale), axis=2)
    return result

Is that correct?
The input shape with the onnx code I showed is (1 (or a different batch size), 8, height, width).
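In other words, the HWC concatenation above would need a transpose to NCHW. A minimal sketch of building the expected input (pack_frames_nchw is a hypothetical helper; it assumes frame_1 and frame_2 are (H, W, 3) float32 arrays in [0, 1]):

import numpy as np

def pack_frames_nchw(frame_1, frame_2, timestep=0.5, scale=1.0):
    height, width, _ = frame_1.shape
    # Stack to (height, width, 8): RGB + RGB + timestep plane + scale plane
    t = np.full((height, width, 1), timestep, dtype=np.float32)
    s = np.full((height, width, 1), scale, dtype=np.float32)
    hwc = np.concatenate((frame_1, frame_2, t, s), axis=2)
    # Transpose to (8, height, width) and add the batch dimension
    return hwc.transpose(2, 0, 1)[np.newaxis, ...]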
Hi everyone, do you have working source code for the onnx conversion? I've been trying to follow the steps for ages and never got it right.
I found the converted rife onnx model at https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases, but the onnx model's input here is slightly different from the input of Practical-RIFE. It seems that the timestep is integrated into the input. Can you share the code for the rife onnx model conversion? I didn't find this code in the project.