Closed jonashaag closed 3 years ago
Indeed we have those 3 representations and I guess it's not very clear to newcomers.
Just to clarify: Asteroid's complex format was designed so that complex filterbanks could be used as real filterbanks, with the same interface as nn.Conv1d. We didn't intend to create a self-contained complex format.
A few points:
IMO, the PyTorch team knows that complex number support is important, and they'll invest the time and energy to get it right. Long term, using native complex will be the easiest option.
We'll keep Asteroid's complex format for filterbanks for a while because it's practical to have them be interchangeable with real ones. When experiments are doable with complex NNs, we'll see what we do, maybe have an as_complex argument in the complex filterbanks, or IDK.
Torchaudio's format is quite practical because we don't need the axis information, but I don't know if we need it in Asteroid.
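To make the three layouts concrete, here is a small sketch of the same data in each format. The shapes follow the descriptions above; the exact axis along which Asteroid concatenates real and imaginary parts is an assumption for illustration.

```python
import torch

n = 4                                       # number of complex filters (assumed)
z = torch.randn(8, n, dtype=torch.cfloat)   # native complex: shape [8, n]

# Torchaudio-style: real/imag stacked in a trailing dimension of size 2.
ta = torch.view_as_real(z)                  # shape [8, n, 2], O(1) view

# Asteroid-style (assumed layout): real and imaginary parts concatenated
# along the filter axis, giving 2 * n real-valued channels.
ast = torch.cat([z.real, z.imag], dim=1)    # shape [8, 2 * n]
```

Note that the torchaudio-style tensor is a zero-copy view of the native one, while the Asteroid-style concatenation allocates a new tensor.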
There is a fourth way of representing complex numbers here where matrix products are supported, as well as solve and other ops.
What would you like to support in this wrapper? Everything that is in complex_nn.py?
Maybe @anjali411, @boeddeker would like to chime in on this?
The wrapper would have a similar/identical interface to Torch's complex numbers, maybe with a few more operations for performance, and with functions to convert between the representations.
```python
class AnyComplex:
    def real(self): ...
    def imag(self): ...
    def abs(self): ...
    def angle(self): ...
    def conj(self): ...
    ...

    # Maybe these could also live outside the class hierarchy.
    def as_torchaudio_complex(self) -> Tensor: ...  # real tensor, shape [..., 2]
    def as_asteroid_complex(self) -> Tensor: ...    # real tensor, shape [..., 2 * n, ...]
    def as_torch_complex(self) -> Tensor: ...       # native complex tensor

    @classmethod
    def from_torchaudio_complex(cls, t: Tensor): ...

    @classmethod
    def from_asteroid_complex(cls, t: Tensor): ...

    @classmethod
    def from_torch_complex(cls, t: Tensor): ...


class TorchComplex(AnyComplex):
    def real(self):
        return self.tensor.real
    ...


class AsteroidComplex(AnyComplex):
    def real(self):
        # First half of the channel axis holds the real part;
        # trailing axes depend on the layout.
        return self.tensor[..., : n // 2, :]
    ...


class TorchaudioComplex(AnyComplex):
    def real(self):
        return self.tensor[..., 0]
    ...
```
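As an illustration, here is one way the TorchaudioComplex part of the sketch could be fleshed out. The class and method names follow the sketch above; this is not an actual Asteroid API, and the wrapped-tensor layout ([..., 2] with real at index 0, imag at index 1) is assumed from the discussion.

```python
import torch
from torch import Tensor


class TorchaudioComplex:
    """Sketch: wraps a real tensor of shape [..., 2] holding (real, imag)."""

    def __init__(self, tensor: Tensor):
        assert tensor.shape[-1] == 2
        self.tensor = tensor

    def real(self) -> Tensor:
        return self.tensor[..., 0]

    def imag(self) -> Tensor:
        return self.tensor[..., 1]

    def abs(self) -> Tensor:
        return torch.sqrt(self.real() ** 2 + self.imag() ** 2)

    def angle(self) -> Tensor:
        return torch.atan2(self.imag(), self.real())

    def as_torch_complex(self) -> Tensor:
        # O(1) reinterpretation as a native complex tensor.
        return torch.view_as_complex(self.tensor.contiguous())

    @classmethod
    def from_torch_complex(cls, t: Tensor) -> "TorchaudioComplex":
        return cls(torch.view_as_real(t))
```

Conversions to and from native complex are zero-copy views, so round-tripping through the wrapper costs nothing.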
Yes, it would be practical to have that. I don't know where it would make sense to put this, though.
@jonashaag Complex Tensors are in the beta stage right now but a lot of linear algebra and autograd support has been added in the last few months and there's more support incoming.
There's native support for only the ComplexFloat and ComplexDouble dtypes. We would like to add support for ComplexHalf and ComplexBFloat16 in the future, but it's not something we are actively working on for the next two releases. However, that is subject to change in case there's growing demand for these dtypes. If you are working with complex data represented as a float or double tensor, it might be a good idea to try the native complex support, for the following reasons:
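For reference, the two natively supported complex dtypes pair with a real dtype for the `.real`/`.imag` accessors, which can be checked directly:

```python
import torch

zf = torch.zeros(3, dtype=torch.complex64)    # ComplexFloat
zd = torch.zeros(3, dtype=torch.complex128)   # ComplexDouble

# Each complex dtype pairs with a real counterpart for .real / .imag.
print(zf.real.dtype)  # torch.float32
print(zd.real.dtype)  # torch.float64
```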
1. There is plenty of NumPy-like functionality already supported for complex tensors: the complex constructor, real, imag, angle, abs. We have also added torch.sgn to get the complex sign and polar to construct complex tensors from abs and angle tensors.
2. You wouldn't have to write custom functions (like matmul, svd, etc.) for a complex operator you'd like to use. If there's an operator you'd like for complex tensors and it isn't already supported, please file an issue on pytorch and add the complex tag. You can check out some of the already added linear algebra operators for complex tensors here: https://github.com/pytorch/pytorch/issues/33152#issue-562792788
Complex operations are vectorized. If there's any reason you'd like to switch to the real representation of complex tensors (…, 2), you can easily use the O(1) view function view_as_real.
3. There's autograd support for complex tensors: https://pytorch.org/docs/master/notes/autograd.html#what-are-complex-derivatives
There’s also some newly added distributed support for complex tensors. Check out https://github.com/pytorch/pytorch/issues/45760
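A quick sketch of the operations mentioned above: the polar constructor, the NumPy-like accessors, and the O(1) round trip between native complex and the real [..., 2] representation.

```python
import math
import torch

# Build a complex tensor from magnitude and phase with torch.polar.
mag = torch.tensor([1.0, 2.0])
phase = torch.tensor([0.0, math.pi / 2])
z = torch.polar(mag, phase)        # ComplexFloat tensor

# NumPy-like accessors.
a = z.abs()                        # magnitudes, ~[1.0, 2.0]
p = z.angle()                      # phases, ~[0.0, pi/2]
s = torch.sgn(z)                   # complex sign: z / |z| (0 stays 0)

# O(1) switch to the real [..., 2] representation and back.
r = torch.view_as_real(z)          # shape [2, 2]
z2 = torch.view_as_complex(r)
```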
Also, please get involved in the complex torch.nn module discussion here: https://github.com/pytorch/pytorch/issues/46374 if there's something you'd like to be supported or if you have any ideas.
Thanks for your answer and your efforts on complex support in PyTorch!
I think Jonas has already spent time with the native complex support in complex_nn.py and in DCCRN and DCUNet (see asteroid/masknn). He can probably give more detail than me, but training didn't seem stable for now (using nightly).
Thanks for the links to the discussions, and also don't hesitate to ping me if you need feedback on specific things.
Thanks for the details @anjali411!
I think BF16 and TF32 support alone would make it worth changing to a real-valued backend, at least for these models. (FP16 does not work well with some of the Asteroid models; it probably needs hand tuning for gradient stability in some places.)
At the moment none of these can be used with PyTorch anyway, so there's no point switching, but let's see how well they work once we can test them.
Should we do this or drop this idea?
I don't know... maybe I will revisit this when/if TF32 or BF16 support has landed in PyTorch and I get access to a GPU that supports them, for faster training. (FP16 training wasn't successful in any of my attempts.)
We have 3 different types of complex numbers in Asteroid: Torch "native" complex, Torchaudio complex (tensors of shape [..., 2]) and Asteroid complex (tensors of shape [..., 2 * n, ...]).
The native complex tensors have a very nice API with things like .angle(), but judging from my personal experience and from the number of bug reports, they are still much less stable for training than the other two representations (which are FP32 based). Another drawback I realised recently is that they don't allow for FP16, BF16, TF32, etc. training. I am not sure whether this is a limitation of the way PyTorch represents these numbers in memory, or simply not yet implemented.

I think there are three ways to deal with the situation:
frustration

edit: patience until we are there.

Interested to hear your thoughts!