Lingstreasure opened 8 months ago
The models produced by dynamo_export are known to run slowly (for now) because they are unoptimized. The API is in beta, and we intend to provide tools to optimize these models for onnxruntime soon.
Is your concern about model-loading time or inference run-time? It may help to also report this in the onnxruntime repo and/or the PyTorch exporter repo. As Justin says above, the transition to the dynamo exporter is in progress (and these concerns should be addressed soon).
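A minimal timing sketch (assuming the standard onnxruntime Python API and the `dynamic_fft.onnx` file exported below) that separates session-creation time from per-run inference time, to help narrow down where the minute goes:

```python
import time

import numpy as np
import onnxruntime as ort

# Time session creation (file load + ORT graph optimization) separately from inference.
start = time.perf_counter()
sess = ort.InferenceSession("dynamic_fft.onnx", providers=["CPUExecutionProvider"])
print(f"session creation: {time.perf_counter() - start:.1f} s")

x = np.random.randn(1, 4, 512, 1024).astype(np.float32)
input_name = sess.get_inputs()[0].name

start = time.perf_counter()
_ = sess.run(None, {input_name: x})
print(f"inference run:    {time.perf_counter() - start:.1f} s")
```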
I trained an inpainting model that uses `torch.fft.rfftn`/`torch.fft.irfftn` and accepts image data of shape [b, 4, h, w]. `torch.onnx.export` cannot export operators with complex tensors, so I used `torch.onnx.dynamo_export` and managed to export with dynamic shapes, but it takes a long time for onnxruntime to load the resulting model. Here is my model: onnx

environment:
```
os: Ubuntu 20.04.5 LTS
onnx==1.14.1
onnxruntime==1.16.0
onnxscript==0.1.0.dev20240304
torch==2.1.1+cu12.1+cudnn8.9.2
```

model.py:
```python
# Fast Fourier Convolution NeurIPS 2020
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1, fft_norm='ortho',
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU):
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = norm_layer(out_channels * 2)
        self.relu = activation_layer(True)
        self.fft_norm = fft_norm

    def forward(self, x):
        batch, channel, h, w = x.shape
        fft_dim = (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (b, c, h, w/2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (b, c, 2, h, w/2+1)
        ffted = ffted.view((batch, 2 * channel, h, -1))  # (b, 2c, h, w/2+1)

        ffted = self.conv_layer(ffted)  # (b, 2c, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, channel, 2, h, -1)).permute(0, 1, 3, 4, 2).contiguous()  # (b, c, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])  # (b, c, h, w/2+1)

        ifft_shape_slice = x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)  # (b, c, h, w)
        return output


class SpectralTransform(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, groups=1,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU, **fu_kwargs):
        super(SpectralTransform, self).__init__()
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()
        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
            norm_layer(out_channels // 2),
            activation_layer(True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups,
            norm_layer=norm_layer, activation_layer=activation_layer, **fu_kwargs)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups)

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)
        output = self.conv2(x + output)
        return output


class FFC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU,
                 padding_type='reflect', **spectral_kwargs):
        super(FFC, self).__init__()
        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(in_cl, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(in_cl, out_cg, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(in_cg, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2,
            norm_layer=norm_layer, activation_layer=activation_layer, **spectral_kwargs)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d

    def forward(self, x):
        x_l, x_g = x if isinstance(x, tuple) else (x, torch.tensor(0.0))
        out_xl, out_xg = torch.tensor(0.0), torch.tensor(0.0)

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g)
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) + self.convg2g(x_g)

        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect', **kwargs):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, norm_layer, activation_layer,
                       padding_type=padding_type, **kwargs)
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g


class FFCResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, **conv_kwargs):
        super().__init__()
        self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer, activation_layer=activation_layer,
                                padding_type=padding_type, **conv_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer, activation_layer=activation_layer,
                                padding_type=padding_type, **conv_kwargs)

    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        id_l, id_g = x_l, x_g

        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))

        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        return out


class ConcatTupleLayer(nn.Module):
    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)


class FFCResNetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect',
                 activation_layer=nn.ReLU,
                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
                 init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
                 add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        model = [nn.ReflectionPad2d(3),
                 FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
                            activation_layer=activation_layer, **init_conv_kwargs)]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            if i == n_downsampling - 1:
                cur_conv_kwargs = dict(downsample_conv_kwargs)
                cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
            else:
                cur_conv_kwargs = downsample_conv_kwargs
            model += [FFC_BN_ACT(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1,
                                 norm_layer=norm_layer,
                                 activation_layer=activation_layer,
                                 **cur_conv_kwargs)]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                          activation_layer=activation_layer,
                                          norm_layer=norm_layer, **resnet_conv_kwargs)
            model += [cur_resblock]

        model += [ConcatTupleLayer()]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]

        if out_ffc:
            model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
                                     norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]

        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(nn.Sigmoid())
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
```

export.py:
```python
import torch

from model import FFCResNetGenerator


if __name__ == "__main__":
    model = FFCResNetGenerator(
        input_nc=4,
        output_nc=3,
        ngf=64,
        n_downsampling=3,
        n_blocks=9,
        init_conv_kwargs={
            "ratio_gin": 0,
            "ratio_gout": 0,
        },
        downsample_conv_kwargs={
            "ratio_gin": 0,
            "ratio_gout": 0,
        },
        resnet_conv_kwargs={
            "ratio_gin": 0.75,
            "ratio_gout": 0.75,
        }
    )
    model.eval()

    input_data = torch.randn(1, 4, 512, 1024)
    args = (input_data,)
    export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
    torch.onnx.dynamo_export(
        model,
        *args,
        export_options=export_options,
    ).save("dynamic_fft.onnx")
    print(f"Dynamic onnx exported to dynamic_fft.onnx")
```

Running this code raises an error:
I modified some code to make the export pass and successfully exported a dynamic-shape ONNX model, but it takes about 1 minute for onnxruntime to load the ONNX file before inference, which is too slow and unacceptable for my task.

To make the dynamic export work, I made the following modifications:
1. Comment out two lines of code in `_subclasses/fake_tensor.py` of the torch (2.1.1) package; find the function `stride_incorrect_op`.
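The two commented lines themselves are not reproduced here; as a rough sketch, assuming they are the lines that make `stride_incorrect_op` flag FFT ops, a similar effect should be achievable by monkey-patching the function from export.py instead of editing the installed package:

```python
# Sketch only: assumes the commented-out lines are the ones that make
# stride_incorrect_op() return True for FFT ops in torch 2.1.x.
import torch._subclasses.fake_tensor as fake_tensor


def _stride_incorrect_op_patched(op):
    # Never flag an op's output strides as incorrect, so the dynamo exporter
    # does not bail out on torch.fft.rfftn / torch.fft.irfftn.
    return False


fake_tensor.stride_incorrect_op = _stride_incorrect_op_patched
```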
Then, executing export.py raises another error:
I register the function `complex()` in my export.py.

export.py:
```python
import onnxscript
import torch
from onnxscript import FLOAT, COMPLEX64
from torch.onnx import register_custom_op_symbolic

from model import FFCResNetGenerator


def register_complex_for_torch_dynamo():
    from onnxscript.onnx_opset import opset18 as op

    custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)

    @onnxscript.script(custom_aten)
    def custom_aten_complex(
        real: FLOAT[1, "C", "H", "W"],
        imag: FLOAT[1, "C", "H", "W"]
    ) -> COMPLEX64[1, "C", "H", "W", 2]:
        real = op.Unsqueeze(real, axes=[-1])
        imag = op.Unsqueeze(imag, axes=[-1])
        return op.Concat(real, imag, axis=-1)

    # register 'aten::complex'
    onnx_registry = torch.onnx.OnnxRegistry()
    onnx_registry.register_op(namespace="aten", op_name="complex", function=custom_aten_complex)
    print(f"aten::complex is supported by ONNX registry: \
          {onnx_registry.is_registered_op(namespace='aten', op_name='complex')}")
    return onnx_registry


if __name__ == "__main__":
    model = FFCResNetGenerator(
        input_nc=4,
        output_nc=3,
        ngf=64,
        n_downsampling=3,
        n_blocks=9,
        init_conv_kwargs={
            "ratio_gin": 0,
            "ratio_gout": 0,
        },
        downsample_conv_kwargs={
            "ratio_gin": 0,
            "ratio_gout": 0,
        },
        resnet_conv_kwargs={
            "ratio_gin": 0.75,
            "ratio_gout": 0.75,
        }
    )
    model.eval()

    input_data = torch.randn(1, 4, 512, 1024)
    args = (input_data,)
    export_options = torch.onnx.ExportOptions(
        onnx_registry=register_complex_for_torch_dynamo(),  ### add here
        dynamic_shapes=True
    )
    torch.onnx.dynamo_export(
        model,
        *args,
        export_options=export_options,
    ).save("dynamic_fft.onnx")
    print(f"Dynamic onnx exported to dynamic_fft.onnx")
```

2. For dynamic shape inference, in function_libs/torch_lib/ops/fft.py of the onnxscript package in my virtual environment, I add a function `_ifftn_onnx()`:
_ifftn_onnx():

```python
@torch_op(
    "aten::_fft_c2r",
    trace_only=True,
    private=True,
    complex=True,
)
def _ifftn_onnx(
    self: TFloat, dims: Sequence[int], normalization: int, last_dim_size: INT64
) -> TFloat:
    """Standard complex to real inverse FFT.

    Args:
        self: The input tensor.
        dims: The dimensions to apply FFT.
        normalization: The normalization mode.
        inverse: Whether to compute the inverse FFT.
        last_dim_size: The size of the last dim.

    Returns:
        The transformed tensor.
    """
    # my model inputs are images with shape [batch, c, h, w],
    # so in this function the `self` tensor has shape [batch, c, h, w or w/2+1, 2]

    # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
    # dimension at the beginning to represent the batch dimension.
    transformed = op.Unsqueeze(self, axes=[0])

    # Add 1 to account for the batch dimension when counting axes from the left
    new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]

    for dim in new_dims[:-1]:
        transformed = op.DFT(transformed, axis=dim, inverse=True, onesided=False)

    # Torch computes one-sided FFT on the last dimension only.
    ######################################################################################
    # There is an error in the DFT operator when `inverse` and `onesided` are both True:
    # Op (DFT) [ShapeInferenceError] is_onesided and inverse attributes cannot be enabled
    # at the same time
    ######################################################################################

    # **** custom irfft implementation ****
    # make conjugate for reverse RFFT
    # the output size of rfft will be x/2 + 1, so complete the conjugate part first.
    transformed_conj = transformed * op.Constant(value_floats=[1.0, -1.0])
    # flip the conjugate part
    transformed_conj = op.Transpose(transformed_conj, perm=[4, 0, 1, 2, 3, 5])
    sequence_len = op.CastLike(last_dim_size / 2 + 1, last_dim_size)
    sequence_lens = op.Expand(sequence_len, shape=[1])
    transformed_conj = op.ReverseSequence(
        transformed_conj, batch_axis=1, time_axis=0, sequence_lens=sequence_lens
    )
    transformed_conj = op.Transpose(transformed_conj, perm=[1, 2, 3, 4, 0, 5])
    # slice out the needed part
    # my input `self` tensor sizes are always even.
    starts = op.Constant(value_ints=[0, 0, 0, 0, 1, 0])
    transformed_conj = op.Slice(
        transformed_conj, starts=starts, ends=op.Shape(transformed)
    )
    # concatenate with the original positive part
    transformed = op.Concat(transformed, transformed_conj, axis=new_dims[-1])

    transformed = op.DFT(
        transformed, last_dim_size, axis=new_dims[-1], inverse=True, onesided=False
    )

    # Remove the batch dimension
    transformed = op.Squeeze(transformed, axes=[0])

    ### Normalize the result. The following code raises an error, so I implement the
    ### normalization in my model instead.
    # ifft of DFT in ONNX is already normalized by 1/n (tested to make sure), so we should `*n` first if `forward` is False
    # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
    # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
    # total_sample_count = last_dim_size
    # for dim_ in dims[:-1]:
    #     total_sample_count = total_sample_count * self_shape[dim_]  # op.Constant(value_int=self_shape[dim_])
    # total_sample_count = op.CastLike(total_sample_count, transformed)
    # if normalization == 1:
    #     # "ortho" - normalize by 1/sqrt(n)
    #     transformed = op.Mul(transformed, op.Sqrt(total_sample_count))
    # elif normalization == 2:
    #     # "forward" - normalize by 1/n
    #     transformed = op.Mul(transformed, total_sample_count)

    return transformed
```

reference:
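As a quick sanity check of the conjugate-completion trick this implementation relies on (complete the one-sided rFFT spectrum with the flipped complex conjugate of the interior bins, then run a plain full inverse DFT), a small torch snippet on a 1-D signal:

```python
import torch

x = torch.randn(8)                                   # even-length real signal
r = torch.fft.rfft(x)                                 # one-sided spectrum, n/2 + 1 = 5 bins
# Complete the spectrum: append the flipped complex conjugate of the interior bins,
# then a plain full inverse DFT recovers the real signal (this is what irfft does).
full = torch.cat([r, torch.conj(r[1:-1]).flip(0)])
x_rec = torch.fft.ifft(full).real

print(torch.allclose(x_rec, x, atol=1e-6))                    # True
print(torch.allclose(torch.fft.irfft(r, n=8), x, atol=1e-6))  # True, matches torch.fft.irfft
```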
Then, find the function `aten__fft_c2r()` and replace its original implementation.

3. Last, to get correct results, I have to finish the normalization of `_ifftn_onnx()` (not finished in step 2) in my model.

in FourierUnit of model.py:
```python
class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1, fft_norm='ortho',
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU):
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = norm_layer(out_channels * 2)
        self.relu = activation_layer(True)
        self.fft_norm = fft_norm

    def forward(self, x):
        batch, channel, h, w = x.shape
        fft_dim = (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")  ### set to `ortho`
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (b, c, h, w/2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (b, c, 2, h, w/2+1)
        ffted = ffted.view((batch, 2 * channel, h, -1))  # (b, 2c, h, w/2+1)

        ffted = self.conv_layer(ffted)  # (b, 2c, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, channel, 2, h, -1)).permute(0, 1, 3, 4, 2).contiguous()  # (b, c, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])  # (b, c, h, w/2+1)

        ifft_shape_slice = x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)  # (b, c, h, w)
        ### I use "ortho" normalization throughout my model.
        output = output * torch.sqrt(torch.tensor(h * w, requires_grad=False))  # add this line
        return output
```
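A quick check of the normalization relation this extra multiplication relies on: the exported inverse DFT already scales by 1/n (torch's "backward" convention, since the ortho branch in `_ifftn_onnx()` is commented out), so multiplying that un-normalized inverse by sqrt(h*w) gives the same output as the "ortho"-normalized inverse:

```python
import torch

h, w = 8, 8
x = torch.randn(1, 3, h, w)
X = torch.fft.rfftn(x, dim=(-2, -1), norm="ortho")

# "backward" inverse scales by 1/(h*w), "ortho" inverse scales by 1/sqrt(h*w),
# so ortho == backward * sqrt(h*w) for any input spectrum X.
ortho = torch.fft.irfftn(X, s=(h, w), dim=(-2, -1), norm="ortho")
backward = torch.fft.irfftn(X, s=(h, w), dim=(-2, -1), norm="backward")
print(torch.allclose(ortho, backward * torch.sqrt(torch.tensor(float(h * w))), atol=1e-5))  # True
```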
After these 3 steps, I successfully export my dynamic model, but creating the onnxruntime inference session is very slow when executing `ort.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])`. I don't know how to handle this. Here is a visualization of my ONNX file:

visualization:
![rfftn-irfftn onnx](https://github.com/onnx/onnx/assets/73473905/15e5d386-48df-40f6-82e1-c385c22a3614)

It has a big subgraph in it, presumably produced by `torch._dynamo`. Maybe that is why onnxruntime loads the ONNX file so slowly? Could anyone give some help?
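For what it's worth, one way to check whether ORT's graph optimization of that big subgraph dominates the load time (and to pay that cost only once) might be to serialize the optimized model; a sketch, assuming the file names used above:

```python
import onnxruntime as ort

# First run: create the session once and write the ORT-optimized graph to disk.
so = ort.SessionOptions()
so.optimized_model_filepath = "dynamic_fft.ort.onnx"  # hypothetical output path
_ = ort.InferenceSession("dynamic_fft.onnx", sess_options=so, providers=["CPUExecutionProvider"])

# Later runs: load the pre-optimized file instead, which may avoid repeating the slow optimization step.
sess = ort.InferenceSession("dynamic_fft.ort.onnx", providers=["CPUExecutionProvider"])
```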