PeterL1n / RobustVideoMatting

Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!
https://peterl1n.github.io/RobustVideoMatting/
GNU General Public License v3.0
8.32k stars · 1.11k forks

ONNX export is invalid #238

Open · skottmckay opened this issue 1 year ago

skottmckay commented 1 year ago

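In ONNX, a named (symbolic) dimension must resolve to a single value each time the model is executed: every axis that carries the same name must have the same size in that run. Because of this, the input/output shapes declared for the recurrent states in the linked export code are incorrect. The four recurrent states differ from one another in channel count and spatial size, so they cannot share dimension names.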
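To make the rule concrete, here is a small pure-Python sketch. It is not part of ONNX, PyTorch or ONNX Runtime, and `find_dim_conflicts` is a hypothetical helper. It collects the sizes each dimension name takes across a `dynamic_axes`-style dict for a given set of concrete shapes and reports any name bound to more than one size. The tensor shapes are made-up values chosen only to give two recurrent states different sizes; they are not taken from the model.

```python
from collections import defaultdict


def find_dim_conflicts(dynamic_axes, shapes):
    """Return {dim_name: sorted sizes} for every name bound to more than one size.

    dynamic_axes: {tensor_name: {axis_index: dim_name}}, as passed to torch.onnx.export
    shapes:       {tensor_name: concrete shape tuple for one model run}
    """
    seen = defaultdict(set)
    for tensor_name, axes in dynamic_axes.items():
        shape = shapes[tensor_name]
        for axis, dim_name in axes.items():
            seen[dim_name].add(shape[axis])
    # A valid run needs exactly one size per name.
    return {name: sorted(sizes) for name, sizes in seen.items() if len(sizes) > 1}


# Hypothetical shapes for two recurrent states of different size.
shapes = {
    'r1i': (1, 16, 144, 256),
    'r2i': (1, 20, 72, 128),
}

# Reusing one set of names for both states: 'channels', 'height' and 'width'
# would each need two different values in the same run.
shared = {0: 'batch_size', 1: 'channels', 2: 'height', 3: 'width'}
print(find_dim_conflicts({'r1i': shared, 'r2i': shared}, shapes))
# {'channels': [16, 20], 'height': [72, 144], 'width': [128, 256]}

# Giving each state its own names removes the conflict.
r1 = {0: 'batch_size', 1: 'r1_channels', 2: 'r1_height', 3: 'r1_width'}
r2 = {0: 'batch_size', 1: 'r2_channels', 2: 'r2_height', 3: 'r2_width'}
print(find_dim_conflicts({'r1i': r1, 'r2i': r2}, shapes))
# {}
```

Only `batch_size` is shared in the second mapping, and both tensors agree on it, so no conflict is reported.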

https://github.com/PeterL1n/RobustVideoMatting/blob/ebead27cb683e157b2bea7ca869daa820a07ba8f/export_onnx.py#LL49C1-L73C15

I believe it should look more like the following, assuming the src input and the fgr and pha outputs have the same batch size, height and width when the model is run:

```python
# self.model, self.args, src and rec come from the surrounding exporter code in export_onnx.py.

# Batch and spatial axes shared by the src input and the fgr/pha outputs.
# Axis 1 (channels) is left static: src and fgr have 3 channels, pha has 1.
dynamic_spatial = {0: 'batch_size', 2: 'height', 3: 'width'}

# Each recurrent state gets its own dimension names, since r1..r4 differ in size.
dynamic_r1 = {0: 'batch_size', 1: 'r1_channels', 2: 'r1_height', 3: 'r1_width'}
dynamic_r2 = {0: 'batch_size', 1: 'r2_channels', 2: 'r2_height', 3: 'r2_width'}
dynamic_r3 = {0: 'batch_size', 1: 'r3_channels', 2: 'r3_height', 3: 'r3_width'}
dynamic_r4 = {0: 'batch_size', 1: 'r4_channels', 2: 'r4_height', 3: 'r4_width'}

torch.onnx.export(
    self.model,
    (src, *rec, downsample_ratio),
    self.args.output,
    export_params=True,
    opset_version=self.args.opset,
    do_constant_folding=True,
    input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i', 'downsample_ratio'],
    output_names=['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o'],
    dynamic_axes={
        'src': dynamic_spatial,
        'fgr': dynamic_spatial,
        'pha': dynamic_spatial,
        'r1i': dynamic_r1,
        'r2i': dynamic_r2,
        'r3i': dynamic_r3,
        'r4i': dynamic_r4,
        'r1o': dynamic_r1,
        'r2o': dynamic_r2,
        'r3o': dynamic_r3,
        'r4o': dynamic_r4,
    })
```
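Writing eight nearly identical recurrent entries by hand invites typos. As an illustration only, the same mapping can be generated in a loop. `build_dynamic_axes` is a hypothetical helper, not something in export_onnx.py, and it mirrors the proposal above, including its assumption that each rNi/rNo pair has the same shape.

```python
def build_dynamic_axes(num_states=4):
    """Build a dynamic_axes dict matching the hand-written proposal (hypothetical helper)."""
    # Shared batch/spatial names for the frame input and the two frame outputs.
    spatial = {0: 'batch_size', 2: 'height', 3: 'width'}
    axes = {'src': spatial, 'fgr': spatial, 'pha': spatial}
    for i in range(1, num_states + 1):
        # Distinct names per recurrent state; input and output of one state share them.
        state = {0: 'batch_size', 1: f'r{i}_channels', 2: f'r{i}_height', 3: f'r{i}_width'}
        axes[f'r{i}i'] = state
        axes[f'r{i}o'] = state
    return axes


axes = build_dynamic_axes()
print(axes['r3o'])  # {0: 'batch_size', 1: 'r3_channels', 2: 'r3_height', 3: 'r3_width'}
```

The returned dict could be passed as `dynamic_axes=build_dynamic_axes()` in the `torch.onnx.export` call; dict ordering differs from the hand-written version, but the key/value pairs are the same.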

Related ONNX Runtime discussion: https://github.com/microsoft/onnxruntime/issues/9433#issuecomment-947214520