cyrta opened this issue 3 years ago (status: Open)
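After investigating, I found that the conversion error originates in `OnReIm`:

```python
def forward(self, x):
    # Apply separate real-valued modules to the real and imaginary parts,
    # then recombine them into a native complex tensor.
    return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
```

which calls the following helpers (both rely on `torch.view_as_complex`):

```python
import torch


def torch_complex_from_magphase(mag, phase):
    # Polar -> Cartesian, then reinterpret the trailing (re, im) pair as complex.
    return torch.view_as_complex(
        torch.stack((mag * torch.cos(phase), mag * torch.sin(phase)), dim=-1)
    )


def torch_complex_from_reim(re, im):
    # Stack real and imaginary parts in a trailing dim of size 2, then view as complex.
    return torch.view_as_complex(torch.stack([re, im], dim=-1))
```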
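For illustration, here is a sketch of real-valued counterparts to these two helpers. They never call `torch.view_as_complex` and instead return the stacked `(..., 2)` tensor (the layout torchaudio used for complex values). The names `reim_stack` and `magphase_stack` are hypothetical and not part of Asteroid; downstream code would then have to work on the trailing real/imaginary dimension instead of on a complex dtype.

```python
import torch


def reim_stack(re: torch.Tensor, im: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: real (..., 2) layout instead of a complex dtype."""
    # Same stacking as torch_complex_from_reim, minus torch.view_as_complex.
    return torch.stack([re, im], dim=-1)


def magphase_stack(mag: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: polar -> Cartesian, kept as a real (..., 2) tensor."""
    return torch.stack((mag * torch.cos(phase), mag * torch.sin(phase)), dim=-1)


if __name__ == "__main__":
    re, im = torch.randn(3, 4), torch.randn(3, 4)
    out = reim_stack(re, im)
    # The real layout carries exactly the same numbers as the complex version.
    assert torch.equal(torch.view_as_complex(out), torch.complex(re, im))
    print(out.shape)  # torch.Size([3, 4, 2])
```

Only the final conversion back to a complex dtype (here used for the check) would stay outside the exported model.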
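asteroid-filterbanks uses the same operation as well:

```python
def to_torch_complex(tensor, dim: int = -2):
    # to_torchaudio splits `tensor` into real/imag halves along `dim` and stacks
    # them into a trailing dim of size 2, which view_as_complex then reinterprets.
    return torch.view_as_complex(to_torchaudio(tensor, dim=dim))
```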
https://github.com/asteroid-team/asteroid-filterbanks/blob/8a3d13fb0e495772bc9d1deac3327affe2833e10/asteroid_filterbanks/transforms.py#L327
The problem is that ONNX doesn't support PyTorch's complex numbers. There isn't much we can do except create a facade for PyTorch's complex numbers that is based on real numbers. It's not a lot of work to implement; in fact, I've implemented it multiple times, but I'm not sure whether we should include that code in Asteroid.
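A minimal sketch of such a facade, assuming complex values are stored as two real tensors. The class `RealComplex` and its methods are hypothetical (this is not the implementation mentioned above, nor an Asteroid API). The point is that every operation reduces to real arithmetic, e.g. (a + bi)(c + di) = (ac - bd) + (ad + bc)i, so no complex dtype appears in the graph.

```python
from dataclasses import dataclass

import torch


@dataclass(eq=False)  # eq=False: element-wise tensor comparison has no single bool
class RealComplex:
    """Hypothetical facade: a complex tensor stored as separate real/imag parts."""

    re: torch.Tensor
    im: torch.Tensor

    @classmethod
    def from_torch(cls, z: torch.Tensor) -> "RealComplex":
        # Only needed at the model boundary, e.g. in eager-mode tests.
        return cls(z.real, z.imag)

    def to_torch(self) -> torch.Tensor:
        return torch.complex(self.re, self.im)

    def __add__(self, other: "RealComplex") -> "RealComplex":
        return RealComplex(self.re + other.re, self.im + other.im)

    def __mul__(self, other: "RealComplex") -> "RealComplex":
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        return RealComplex(
            self.re * other.re - self.im * other.im,
            self.re * other.im + self.im * other.re,
        )

    def abs(self) -> torch.Tensor:
        # |z| = sqrt(re^2 + im^2)
        return torch.sqrt(self.re ** 2 + self.im ** 2)

    def angle(self) -> torch.Tensor:
        return torch.atan2(self.im, self.re)

    def conj(self) -> "RealComplex":
        return RealComplex(self.re, -self.im)


if __name__ == "__main__":
    a = torch.randn(5, dtype=torch.complex64)
    b = torch.randn(5, dtype=torch.complex64)
    prod = (RealComplex.from_torch(a) * RealComplex.from_torch(b)).to_torch()
    assert torch.allclose(prod, a * b, atol=1e-6)
```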
I guess that's possible for simple operations, but when linear solves and eigenvalue decompositions are computed, the facade becomes more complicated, right?
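Linear solves, at least, can be expressed in real arithmetic through the standard real embedding of a complex system. Writing A = A_r + iA_i, b = b_r + ib_i and x = x_r + ix_i, the system Ax = b is equivalent to the real block system [[A_r, -A_i], [A_i, A_r]] [x_r; x_i] = [b_r; b_i]. The sketch below (function name hypothetical, unbatched inputs assumed) checks this against `torch.linalg.solve` on complex inputs. It only shows that the facade can express a solve in real terms; it says nothing about whether the exporter supports `torch.linalg.solve` itself, and eigendecompositions are not covered.

```python
import torch


def complex_solve_real(a_re, a_im, b_re, b_im):
    """Hypothetical helper: solve (A_r + i A_i) x = b_r + i b_i with real tensors only.

    Shapes (unbatched): A parts are (n, n), b parts are (n,) or (n, k).
    """
    n = a_re.shape[-1]
    # Real 2n x 2n embedding [[A_r, -A_i], [A_i, A_r]].
    top = torch.cat([a_re, -a_im], dim=-1)
    bottom = torch.cat([a_im, a_re], dim=-1)
    big_a = torch.cat([top, bottom], dim=-2)
    big_b = torch.cat([b_re, b_im], dim=0)
    x = torch.linalg.solve(big_a, big_b)
    # The first n rows hold the real part, the last n rows the imaginary part.
    return x[:n], x[n:]


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(4, 4, dtype=torch.complex128)
    b = torch.randn(4, dtype=torch.complex128)
    x_re, x_im = complex_solve_real(a.real, a.imag, b.real, b.imag)
    assert torch.allclose(torch.complex(x_re, x_im), torch.linalg.solve(a, b))
```

The price is a system twice the size, which is acceptable for the small matrices typical of per-frequency beamforming-style computations.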
I don't really know what we should do about that.
@jonashaag

> It's not a lot of work to implement [...]

Can you explain your approach? Maybe show a code snippet? Are you thinking of a dual-path design, or of doubling the input size of standard tensors (one half for magnitude, the other for phase) going into each module?
Unfortunately I don't have access to the code for a few days.
The approach I've taken is as follows:

- Change all `t.abs()`, `t.angle()`, `torch.view_as_complex(t)` etc. calls on complex tensors to `complex_nn.abs(t)` etc. (and of course implement them there).
- Change `view_as_complex` to return your preferred complex representation. See the module docstring in `complex_nn`.
- Adapt `.permute()`, `.transpose()` and slicing. Unfortunately, these are scattered all over the place in the models, so you'll have to make those changes step by step until it works.

Thanks a lot! I am debugging some of the models with the Asteroid complex representation and there are still some errors. I think all of the error locations are in `complex_nn`.
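The first and third steps of that approach could look roughly like the sketch below, assuming a torchaudio-style layout where the real and imaginary parts sit in a trailing dimension of size 2. The function names are hypothetical stand-ins for `complex_nn.abs(t)` and friends, not Asteroid's actual `complex_nn` API. `complex_transpose` illustrates why permutes and transposes need attention: with this layout, negative dimension indices that refer to the complex shape must skip the trailing real/imag axis.

```python
import torch


def complex_abs(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical replacement for z.abs() on a (..., 2) real/imag tensor."""
    return torch.sqrt(t[..., 0] ** 2 + t[..., 1] ** 2)


def complex_angle(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical replacement for z.angle() on a (..., 2) real/imag tensor."""
    return torch.atan2(t[..., 1], t[..., 0])


def complex_transpose(t: torch.Tensor, dim0: int, dim1: int) -> torch.Tensor:
    """Hypothetical replacement for z.transpose(dim0, dim1); never moves the last axis."""
    # Negative dims refer to the complex shape, so shift them past the size-2 axis.
    dim0 = dim0 - 1 if dim0 < 0 else dim0
    dim1 = dim1 - 1 if dim1 < 0 else dim1
    return t.transpose(dim0, dim1)


if __name__ == "__main__":
    z = torch.randn(2, 3, 5, dtype=torch.complex64)
    t = torch.view_as_real(z)  # eager mode only, to build the (..., 2) layout
    assert torch.allclose(complex_abs(t), z.abs(), atol=1e-6)
    assert torch.allclose(complex_angle(t), z.angle(), atol=1e-6)
    swapped = complex_transpose(t, -1, -2).contiguous()
    assert torch.equal(torch.view_as_complex(swapped), z.transpose(-1, -2))
```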
🐛 Bug
When exporting some of the models that use complex operations to ONNX, the export fails because the ONNX opset does not support complex casting (`torch.view_as_complex(input)`).
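A minimal module that exercises the same op is sketched below. It is an assumed illustration, not the reporter's original code sample, and `ViewAsComplexModel` is hypothetical. Exporter behavior and error messages differ across PyTorch versions, so the script simply reports whether `torch.onnx.export` succeeded instead of asserting a particular error.

```python
import io

import torch
from torch import nn


class ViewAsComplexModel(nn.Module):
    """Hypothetical minimal model: real (..., 2) input -> complex view -> magnitude."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = torch.view_as_complex(x)  # the op the exporter trips over in this issue
        return z.abs()


def try_export(model: nn.Module, example: torch.Tensor) -> str:
    buffer = io.BytesIO()
    try:
        torch.onnx.export(model, (example,), buffer)
    except Exception as exc:  # exporter errors vary by PyTorch version
        return f"export failed: {type(exc).__name__}: {exc}"
    return f"export succeeded ({len(buffer.getvalue())} bytes)"


if __name__ == "__main__":
    model = ViewAsComplexModel().eval()
    example = torch.randn(1, 8, 2)
    print(model(example).shape)  # eager mode works: torch.Size([1, 8])
    print(try_export(model, example))
```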
To Reproduce
Steps to reproduce the behavior (code sample and stack trace):
Expected behavior
The conversion should complete without errors and produce a valid ONNX model.
Environment
Package versions
Run `asteroid-versions` and paste the output here:

Additional info
I know this case is not covered by the PyTorch-to-ONNX opset.
ONNX current operations:
issues with errors:
However, we could propose a wrapper that handles this conversion so that the ONNX model is created properly.
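One possible shape for such a wrapper, sketched under the assumption that the exported graph carries complex values as real tensors with a trailing dimension of size 2 (real, imaginary). `ReImWrapper` and its constructor arguments are hypothetical; it mirrors the `forward` of `OnReIm` shown earlier but never creates a complex dtype, so `torch.view_as_complex` does not appear in the traced graph.

```python
import torch
from torch import nn


class ReImWrapper(nn.Module):
    """Hypothetical ONNX-friendly variant of OnReIm on (..., 2) real tensors."""

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        # Two independent real-valued modules, one per component.
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x[..., 0] is the real part, x[..., 1] the imaginary part.
        re = self.re_module(x[..., 0])
        im = self.im_module(x[..., 1])
        return torch.stack([re, im], dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = ReImWrapper(nn.Linear, 8, 4)
    z = torch.randn(3, 8, dtype=torch.complex64)
    out = layer(torch.view_as_real(z))  # view_as_real used outside the model only
    expected = torch.complex(layer.re_module(z.real), layer.im_module(z.imag))
    assert torch.allclose(torch.view_as_complex(out), expected)
```

Converting to and from a complex dtype then happens only at the model boundary, outside the exported graph.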
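`view_as_complex` is implemented in the ATen library:

```cpp
// Strides of the complex view: the last stride (which must be 1) is dropped
// and the remaining strides are halved.
const auto new_strides = computeStrideForViewAsComplex(self.strides());
// Matching complex dtype, e.g. float -> c10::complex<float>.
const auto complex_type = c10::toComplexType(self.scalar_type());
// Reinterpret the same storage with the new dtype, offset, sizes and strides (no copy).
view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
```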
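The stride computation above can be observed from Python. `view_as_complex` drops the trailing size-2 dimension, halves the remaining strides (they are now counted in complex elements), switches to the matching complex dtype, and shares the original storage without copying. This reinterpretation of the same memory under a complex dtype is the step the exporter could not map to the ONNX opset in this issue.

```python
import torch

x = torch.randn(3, 4, 2)             # float32, contiguous
z = torch.view_as_complex(x)         # complex64 view of the same memory

# Trailing size-2 dim dropped, remaining strides halved: (8, 2, 1) -> (4, 1).
assert x.stride() == (8, 2, 1)
assert z.shape == (3, 4) and z.stride() == (4, 1)
assert z.dtype == torch.complex64    # float -> complex<float>
assert z.data_ptr() == x.data_ptr()  # no copy

# Writing through the complex view changes the real tensor.
z[0, 0] = complex(1.0, 2.0)
assert x[0, 0].tolist() == [1.0, 2.0]

# A trailing dimension whose stride is not 1 is rejected.
y = torch.randn(2, 3).t()            # shape (3, 2), but the last stride is 3
try:
    torch.view_as_complex(y)
except RuntimeError as exc:
    print("rejected:", exc)
```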