onnx / onnx

Open standard for machine learning interoperability
https://onnx.ai/
Apache License 2.0
17.75k stars 3.67k forks source link

[ONNX] ONNX file exported with `torch.onnx.dynamo_export` (dynamic_shapes=True) loads very slowly in onnxruntime #6016

Open Lingstreasure opened 7 months ago

Lingstreasure commented 7 months ago

I trained an inpainting model that contains torch.rfftn / torch.irfftn modules and accepts image data with shape [b, 4, h, w]. For some reason, torch.onnx.export can't export operators with complex tensors. I managed to perform a dynamic export with torch.onnx.dynamo_export, but onnxruntime takes a long time to load the result. Here is my model: onnx

environment:
```
os: Ubuntu 20.04.5 LTS
onnx==1.14.1
onnxruntime==1.16.0
onnxscript==0.1.0.dev20240304
torch==2.1.1+cu12.1+cudnn8.9.2
```


model.py:
```python
# Fast Fourier Convolution NeurIPS 2020
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf

import torch
import torch.nn as nn


class FourierUnit(nn.Module):

    def __init__(self, in_channels, out_channels, groups=1, fft_norm='ortho', norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU):
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2, kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = norm_layer(out_channels * 2)
        self.relu = activation_layer(True)
        self.fft_norm = fft_norm

    def forward(self, x):
        batch, channel, h, w = x.shape
        fft_dim = (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (b, c, h, w/2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (b, c, 2, h, w/2+1)
        ffted = ffted.view((batch, 2 * channel, h, -1))  # (b, 2c, h, w/2+1)
        ffted = self.conv_layer(ffted)  # (b, 2c, h, w/2+1)
        ffted = self.relu(self.bn(ffted))
        ffted = ffted.view((batch, channel, 2, h, -1)).permute(0, 1, 3, 4, 2).contiguous()  # (b, c, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])  # (b, c, h, w/2+1)
        ifft_shape_slice = x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)  # (b, c, h, w)
        return output


class SpectralTransform(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1, groups=1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU, **fu_kwargs):
        super(SpectralTransform, self).__init__()
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()
        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
            norm_layer(out_channels // 2),
            activation_layer(True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups, norm_layer=norm_layer, activation_layer=activation_layer, **fu_kwargs)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups)

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)
        output = self.conv2(x + output)
        return output


class FFC(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride=1, padding=0, dilation=1, groups=1, bias=False, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU, padding_type='reflect', **spectral_kwargs):
        super(FFC, self).__init__()
        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride
        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg
        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg
        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, norm_layer=norm_layer, activation_layer=activation_layer, **spectral_kwargs)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d

    def forward(self, x):
        x_l, x_g = x if isinstance(x, tuple) else (x, torch.tensor(0.0))
        out_xl, out_xg = torch.tensor(0.0), torch.tensor(0.0)
        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g)
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) + self.convg2g(x_g)
        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride=1, padding=0, dilation=1, groups=1, bias=False, norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity, padding_type='reflect', **kwargs):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding, dilation, groups, bias, norm_layer, activation_layer, padding_type=padding_type, **kwargs)
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)
        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g


class FFCResnetBlock(nn.Module):

    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, **conv_kwargs):
        super().__init__()
        self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, norm_layer=norm_layer, activation_layer=activation_layer, padding_type=padding_type, **conv_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, norm_layer=norm_layer, activation_layer=activation_layer, padding_type=padding_type, **conv_kwargs)

    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        id_l, id_g = x_l, x_g
        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))
        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        return out


class ConcatTupleLayer(nn.Module):

    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)


class FFCResNetGenerator(nn.Module):

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, padding_type='reflect', activation_layer=nn.ReLU, up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={}, add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        model = [nn.ReflectionPad2d(3),
                 FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer, activation_layer=activation_layer, **init_conv_kwargs)]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            if i == n_downsampling - 1:
                cur_conv_kwargs = dict(downsample_conv_kwargs)
                cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
            else:
                cur_conv_kwargs = downsample_conv_kwargs
            model += [FFC_BN_ACT(min(max_features, ngf * mult), min(max_features, ngf * mult * 2), kernel_size=3, stride=2, padding=1, norm_layer=norm_layer, activation_layer=activation_layer, **cur_conv_kwargs)]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer, norm_layer=norm_layer, **resnet_conv_kwargs)
            model += [cur_resblock]

        model += [ConcatTupleLayer()]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult), min(max_features, int(ngf * mult / 2)), kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]

        if out_ffc:
            model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer, norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]

        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(nn.Sigmoid())
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
```


export.py:
```python
import torch
from models import FFCResNetGenerator

if __name__ == "__main__":
    model = FFCResNetGenerator(
        input_nc=4,
        output_nc=3,
        ngf=64,
        n_downsampling=3,
        n_blocks=9,
        init_conv_kwargs={
            "ratio_gin": 0,
            "ratio_gout": 0,
        },
        downsample_conv_kwargs={
            "ratio_gin": 0,
            "ratio_gout": 0,
        },
        resnet_conv_kwargs={
            "ratio_gin": 0.75,
            "ratio_gout": 0.75,
        }
    )
    model.eval()

    input_data = torch.randn(1, 4, 512, 1024)
    args = (input_data,)
    export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
    torch.onnx.dynamo_export(
        model,
        *args,
        export_options=export_options,
    ).save("dynamic_fft.onnx")
    print(f"Dynamic onnx exported to dynamic_fft.onnx")
```

As written, this code raises an error when executed:

warnings.warn(
Traceback (most recent call last):
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1397, in run_node
    return node.target(*args, **kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 475, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten._fft_r2c.default



I modified some code to make it pass and successfully exported a dynamic-shape ONNX model. However, it takes about 1 minute to load the dynamic-shape ONNX file with onnxruntime for inference, which is too slow to be acceptable for my task.

To get the dynamic export working, I made the following changes:

  1. Comment out 2 lines of code in `_subclasses/fake_tensor.py` of the torch (2.1.1) package, inside the function `stride_incorrect_op`:

    def stride_incorrect_op(op):
       # Patched copy of torch 2.1.1 `_subclasses/fake_tensor.py::stride_incorrect_op`.
       # NOTE(review): judging by the UnsupportedOperatorException traceback in this report,
       # ops for which this returns True are routed to `wordaround_stride_incorrect_op`,
       # which raised for aten._fft_r2c.default — confirm against the torch source.
       if op.namespace not in ("aten", "prims"):
           return False
       # `aten` is presumably torch.ops.aten, defined outside this snippet.
       if op is aten._fft_c2c.default:
           return False
    
       # op_name is now unused: its only consumer (the fft check below) is commented out.
       op_name = op.name()
       # if "fft" in op_name:  ### comment the condition expr
       #     return True
       # With the fft check disabled, every path through this function returns False.
       return False


    Then, running export.py raises another error:

    Traceback (most recent call last):
     File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/exporter.py", line 1190, in dynamo_export
       return Exporter(
     File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/exporter.py", line 950, in export
       graph_module = pre_export_passes(
     File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/exporter.py", line 1250, in pre_export_passes
       analysis.UnsupportedFxNodesAnalysis(
     File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 74, in analyze
       self._lint(analysis_result, diagnostic_level)
     File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 38, in _lint
       self.diagnostic_context.log_and_raise_if_error(diagnostic)
     File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 367, in log_and_raise_if_error
       raise RuntimeErrorWithDiagnostic(diagnostic)
    torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.complex.default']}.


    So I registered a custom `complex()` function in my export.py:

    export.py:
    ```python
    import onnxscript
    import torch
    from onnxscript import FLOAT, COMPLEX64
    from torch.onnx import register_custom_op_symbolic

    from model import FFCResNetGenerator


    def register_complex_for_torch_dynamo():
        from onnxscript.onnx_opset import opset18 as op

        custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)

        @onnxscript.script(custom_aten)
        def custom_aten_complex(
            real: FLOAT[1, "C", "H", "W"], imag: FLOAT[1, "C", "H", "W"]
        ) -> COMPLEX64[1, "C", "H", "W", 2]:
            real = op.Unsqueeze(real, axes=[-1])
            imag = op.Unsqueeze(imag, axes=[-1])
            return op.Concat(real, imag, axis=-1)

        # register 'aten::complex'
        onnx_registry = torch.onnx.OnnxRegistry()
        onnx_registry.register_op(namespace="aten", op_name="complex", function=custom_aten_complex)
        print(f"aten::complex is supported by ONNX registry: \
        {onnx_registry.is_registered_op(namespace='aten', op_name='complex')}"
        )
        return onnx_registry


    if __name__ == "__main__":
        model = FFCResNetGenerator(
            input_nc=4,
            output_nc=3,
            ngf=64,
            n_downsampling=3,
            n_blocks=9,
            init_conv_kwargs={
                "ratio_gin": 0,
                "ratio_gout": 0,
            },
            downsample_conv_kwargs={
                "ratio_gin": 0,
                "ratio_gout": 0,
            },
            resnet_conv_kwargs={
                "ratio_gin": 0.75,
                "ratio_gout": 0.75,
            }
        )
        model.eval()

        input_data = torch.randn(1, 4, 512, 1024)
        args = (input_data,)
        export_options = torch.onnx.ExportOptions(
            onnx_registry=register_complex_for_torch_dynamo(),  ### add here
            dynamic_shapes=True
        )
        torch.onnx.dynamo_export(
            model,
            *args,
            export_options=export_options,
        ).save("dynamic_fft.onnx")
        print(f"Dynamic onnx exported to dynamic_fft.onnx")
    ```


  2. To support dynamic shape inference, I added a function `_ifftn_onnx()` to `function_libs/torch_lib/ops/fft.py` of the onnxscript package in my virtual environment:

    _ifftn_onnx():
    ```python
    @torch_op(
        "aten::_fft_c2r",
        trace_only=True,
        private=True,
        complex=True,
    )
    def _ifftn_onnx(
        self: TFloat, dims: Sequence[int], normalization: int, last_dim_size: INT64
    ) -> TFloat:
        """Standard complex to real inverse FFT.

        Args:
            self: The input tensor.
            dims: The dimensions to apply FFT.
            normalization: The normalization mode.
            inverse: Whether to compute the inverse FFT.
            last_dim_size: The size of last dim

        Returns:
            The transformed tensor.
        """

        # my model inputs are images, which have shape: [batch, c, h, w]
        # so in this function, the `self` tensor will have a shape: [batch, c, h, w or w/2+1, 2]

        # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
        # dimension at the beginning to represent the batch dimension.
        transformed = op.Unsqueeze(self, axes=[0])

        # Add 1 to account for the batch dimension when counting axes from the left
        new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]

        for dim in new_dims[:-1]:
            transformed = op.DFT(transformed, axis=dim, inverse=True, onesided=False)

        # Torch computers one-sided FFT on the last dimension only.
        ######################################################################################
        # There is an error in DFT opeartor when `inverse` and `onesided` are both True:
        # Op (DFT) [ShapeInferenceError] is_onesided and inverse attributes cannot be enabled
        # at the same time
        #####################################################################################
        # **** custom irfft implementation ****
        # make conjugate for reverse RFFT
        # the output size of rfft will be x/2 + 1, so complete the conjugate part first.
        transformed_conj = transformed * op.Constant(value_floats=[1.0, -1.0])
        # flip the conjugate part
        transformed_conj = op.Transpose(transformed_conj, perm=[4, 0, 1, 2, 3, 5])
        sequence_len = op.CastLike(last_dim_size / 2 + 1, last_dim_size)
        sequence_lens = op.Expand(sequence_len, shape=[1])
        transformed_conj = op.ReverseSequence(
            transformed_conj, batch_axis=1, time_axis=0, sequence_lens=sequence_lens
        )
        transformed_conj = op.Transpose(transformed_conj, perm=[1, 2, 3, 4, 0, 5])
        # slice out the needed part
        # my input `self` tensor sizes are always evens.
        starts = op.Constant(value_ints=[0, 0, 0, 0, 1, 0])
        transformed_conj = op.Slice(
            transformed_conj, starts=starts, ends=op.Shape(transformed)
        )
        # concatenate with original positive part
        transformed = op.Concat(transformed, transformed_conj, axis=new_dims[-1])
        transformed = op.DFT(
            transformed, last_dim_size, axis=new_dims[-1], inverse=True, onesided=False
        )
        # Remove the batch dimension
        transformed = op.Squeeze(transformed, axes=[0])

        ### Normalize the result. The followed code will raise error, I implement normalization in my model.
        # ifft of DFT in ONNX has already normed with 1/n (test for sure), so we should `*n` first if `forward` is False
        # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
        # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
        # total_sample_count = last_dim_size
        # for dim_ in dims[:-1]:
        #     total_sample_count = total_sample_count * self_shape[dim_]#op.Constant(value_int=self_shape[dim_])
        # total_sample_count = op.CastLike(total_sample_count, transformed)
        # if normalization == 1:
        #     # "ortho" - normalize by 1/sqrt(n)
        #     transformed = op.Mul(transformed, op.Sqrt(total_sample_count))
        # elif normalization == 2:
        #     # "forward" - normalize by 1/n
        #     transformed = op.Mul(transformed, total_sample_count)
        return transformed
    ```

    reference:

    func: _fftn_onnx in function_libs/torch_lib/ops/fft.py of the onnxscript package

    numpy


    Then, find the function `aten__fft_c2r()` and replace its original implementation:

    @torch_op("aten::_fft_c2r", trace_only=True, complex=True)
    def aten__fft_c2r(
       self: TFloat,
       dim: Sequence[int],
       normalization: int,
       last_dim_size: INT64,  # used: forwarded to _ifftn_onnx (no longer an unused argument)
    ) -> TFloat:
       """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
    
       Complex to real inverse FFT.

       Patched version: delegates to the custom _ifftn_onnx instead of _fftn_onnx.
       `normalization` is still passed along, but the _ifftn_onnx shown in this report
       keeps its normalization code commented out, so the output is NOT rescaled for
       `normalization` here; the reporter compensates inside the model instead.
       """
    
       self_rank = len(self.shape)
       # ONNX DFT input assumes the last dimension is the complex dimension.
       # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
       # Negative dims are shifted by one extra position (for that trailing complex axis)
       # and made non-negative; non-negative dims are kept unchanged.
       dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
       # transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=True)  ### comment this line
       transformed = _ifftn_onnx(self, dim, normalization, last_dim_size=last_dim_size)   ### add this one
       # Take only the real part
       real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])
    
       # Drop the trailing axis, which the Slice above reduced to size 1.
       return op.Squeeze(real_part, axes=[-1])


  3. Finally, to get correct results, I have to complete the normalization that `_ifftn_onnx()` leaves unfinished (see step 2) inside my model:

    in FourierUnit of model.py:
    ```python
    class FourierUnit(nn.Module):

        def __init__(self, in_channels, out_channels, groups=1, fft_norm='ortho', norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU):
            super(FourierUnit, self).__init__()
            self.groups = groups
            self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2, kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
            self.bn = norm_layer(out_channels * 2)
            self.relu = activation_layer(True)
            self.fft_norm = fft_norm

        def forward(self, x):
            batch, channel, h, w = x.shape
            fft_dim = (-2, -1)
            ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")  ### set to `ortho`
            ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (b, c, h, w/2+1, 2)
            ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (b, c, 2, h, w/2+1)
            ffted = ffted.view((batch, 2 * channel, h, -1))  # (b, 2c, h, w/2+1)
            ffted = self.conv_layer(ffted)  # (b, 2c, h, w/2+1)
            ffted = self.relu(self.bn(ffted))
            ffted = ffted.view((batch, channel, 2, h, -1)).permute(0, 1, 3, 4, 2).contiguous()  # (b, c, h, w/2+1, 2)
            ffted = torch.complex(ffted[..., 0], ffted[..., 1])  # (b, c, h, w/2+1)
            ifft_shape_slice = x.shape[-2:]
            output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)  # (b, c, h, w)
            ### I use "ortho" normalization through my model.
            output = output * torch.sqrt(torch.tensor(h * w, requires_grad=False))  # add this line
            return output
    ```


After these 3 steps, I successfully exported my dynamic model, but onnxruntime is very slow when I run `ort.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])`. I don't know how to handle this. Here is the visualization of my ONNX file:

visualization: ![rfftn-irfftn onnx](https://github.com/onnx/onnx/assets/73473905/15e5d386-48df-40f6-82e1-c385c22a3614)


It contains a big subgraph, presumably produced by torch._dynamo. Could that be why onnxruntime loads the ONNX file so slowly? Could anyone help?

justinchuby commented 6 months ago

The models produced by dynamo_export are known to run slowly (for now) because they are unoptimized. The API is in beta, and we intend to provide tools to optimize these models for onnxruntime soon.

gramalingam commented 6 months ago

Is your concern about model-loading time or inference run-time? It may help to also report this in the onnxruntime repo and/or the pytorch exporter repo. As Justin says above, the transition to dynamo-exporter is in progress (and these concerns should be addressed soon).