ivannz / cplxmodule

Complex-valued neural networks for pytorch and Variational Dropout for real and complex layers.
MIT License

cplxmodule.nn.CplxBatchNorm1d is not ONNX exportable #9

Closed: pfeatherstone closed this issue 3 years ago

pfeatherstone commented 4 years ago

Steps to reproduce

import torch
import torch.nn as nn
import cplxmodule
import cplxmodule.nn

def _cplxFrom2d_func(x):
    # Interpret the trailing dimension of size 2 as (real, imaginary)
    return cplxmodule.Cplx(x[..., 0], x[..., 1])

def _cplxTo2d_func(x):
    # Pack the real and imaginary parts back into a trailing dimension of size 2
    return torch.stack([x.real, x.imag], dim=-1)

class _cplxFrom2d(nn.Module):
    def __init__(self):
        super(_cplxFrom2d, self).__init__()

    def forward(self, x):
        return _cplxFrom2d_func(x)

class _cplxTo2d(nn.Module):
    def __init__(self):
        super(_cplxTo2d, self).__init__()

    def forward(self, x):
        return _cplxTo2d_func(x)

model = nn.Sequential(_cplxFrom2d(),
                      cplxmodule.nn.CplxBatchNorm1d(1),
                      _cplxTo2d()).eval()

# Real-valued input of shape (batch, channels, length, 2)
input = torch.randn(1, 1, 1024, 2)
torch.onnx.export(model,
                  (input,),
                  "file.onnx",
                  opset_version=12,
                  input_names=['in'],
                  output_names=['out'],
                  dynamic_axes={'in': [0, 2], 'out': [0]})

I get the following stack trace:

  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 208, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 530, in _export
    fixed_batch_size=fixed_batch_size)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 409, in _model_to_graph
    _export_onnx_opset_version)
IndexError: index out of range in self
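
The two helper modules in the reproduction only change layout: the trailing dimension of size 2 holds the real and imaginary parts, so a (1, 1, 1024, 2) input becomes a complex batch of shape (1, 1, 1024), i.e. batch 1, one channel, length 1024, which is what CplxBatchNorm1d(1) expects. A minimal sketch of that round trip with plain tensors, assuming a (real, imag) tuple as a stand-in for cplxmodule.Cplx; the helper names split_2d and stack_2d are hypothetical:

```python
import torch

def split_2d(x):
    # Hypothetical stand-in for _cplxFrom2d_func: returns (real, imag) as a plain tuple
    return x[..., 0], x[..., 1]

def stack_2d(re, im):
    # Mirrors _cplxTo2d_func: re-pack real and imaginary parts along a new last dim
    return torch.stack([re, im], dim=-1)

torch.manual_seed(0)
x = torch.randn(1, 1, 1024, 2)   # same shape as the input in the reproduction
re, im = split_2d(x)             # each part has shape (1, 1, 1024)
y = stack_2d(re, im)             # back to (1, 1, 1024, 2)

assert re.shape == (1, 1, 1024)
assert torch.equal(x, y)         # the conversion is lossless
```
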
pfeatherstone commented 4 years ago

I'm using PyTorch 1.6.0.

ivannz commented 4 years ago

Thank you for the issue! I will investigate it after the 15th of August.

ivannz commented 4 years ago

I traced the issue down to lines 89-90 in batchnorm.py; specifically, the JIT compiler did not like

    cov_uu, cov_uv = running_cov[0, 0], running_cov[0, 1]
    cov_vu, cov_vv = running_cov[1, 0], running_cov[1, 1]

This is extremely odd to me, especially since running_cov here is either None or a 3D tensor of shape 2 x 2 x ..., and thus cannot be indexed out of range. I was able to resolve the particular case you provided with the following patch:

    cov_uu, cov_uv, cov_vu, cov_vv = running_cov.reshape(4, -1)

I don't particularly like this patch, since it introduces an unnecessary reshape and an iteration over dim 0.
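One way to check that the reshape-based patch is equivalent to the original indexing: for a 2 x 2 x F tensor, reshape(4, -1) lays the (0, 0), (0, 1), (1, 0) and (1, 1) blocks out as rows 0 to 3 in row-major order, so unpacking along dim 0 yields the same four length-F tensors. A minimal sketch, assuming a random running_cov with F = 5 features (the value of F is illustrative):

```python
import torch

torch.manual_seed(0)
num_features = 5                                # illustrative F
running_cov = torch.randn(2, 2, num_features)   # stand-in for the 2 x 2 x F running covariance

# Original element-wise indexing (lines 89-90 of batchnorm.py)
cov_uu, cov_uv = running_cov[0, 0], running_cov[0, 1]
cov_vu, cov_vv = running_cov[1, 0], running_cov[1, 1]

# Patched version: flatten the leading 2 x 2 block and unpack along dim 0
p_uu, p_uv, p_vu, p_vv = running_cov.reshape(4, -1)

# Row-major order maps (0,0), (0,1), (1,0), (1,1) to rows 0..3, so both agree
for a, b in [(cov_uu, p_uu), (cov_uv, p_uv), (cov_vu, p_vu), (cov_vv, p_vv)]:
    assert torch.equal(a, b)
```
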

Furthermore, it turns out that the torch.var operator is not supported by the ONNX opset. This can be circumvented by using

    var = (tensor * tensor).mean(dim=axes) + nugget

instead of

    var = tensor.var(dim=axes, unbiased=False) + nugget

This still computes the (biased) variance, since tensor has already been centred by subtracting the batch average.
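A small numerical check of that identity, assuming a hypothetical (batch, features, length) tensor that is centred over the reduction axes before the variance is taken, as in the batch-norm code; the shapes, axes and nugget value are illustrative:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 3, 16)    # hypothetical batch: (batch, features, length)
axes = (0, 2)                # reduce over every dim except the feature dim
nugget = 1e-5                # small stabilising constant

# Centre over the reduction axes first, as the batch-norm code does
tensor = x - x.mean(dim=axes, keepdim=True)

# ONNX-friendly form: mean of squares of the centred tensor
var_onnx = (tensor * tensor).mean(dim=axes) + nugget
# Reference: biased (population) variance via torch.var
var_ref = tensor.var(dim=axes, unbiased=False) + nugget

assert var_onnx.shape == (3,)
assert torch.allclose(var_onnx, var_ref, atol=1e-6)
```

Without the centring step the two expressions would differ by the squared mean, so the rewrite relies on tensor being centred.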

However, quirks like these suggest to me that some understanding is missing on my part, and that many other objects in cplxmodule may need testing for ONNX/JIT compatibility. I am putting together a unit test suite for ONNX compatibility.

pfeatherstone commented 4 years ago

Thank you, this is great.

pfeatherstone commented 3 years ago

So applying the following fix works:

cov_uu, cov_uv, cov_vu, cov_vv = running_cov.reshape(4, -1)

Any reason why this can't be committed?

pfeatherstone commented 3 years ago

Sod it, I've added a PR: https://github.com/ivannz/cplxmodule/pull/12

ivannz commented 3 years ago

The issue has been fixed by PR #14.