asteroid-team / asteroid

The PyTorch-based audio source separation toolkit for researchers
https://asteroid-team.github.io/
MIT License

[Discussion] Complex numbers in Asteroid #290

Closed · jonashaag closed 3 years ago

jonashaag commented 3 years ago

We have 3 different types of complex numbers in Asteroid: Torch "native" complex, Torchaudio complex (tensors of shape [..., 2]) and Asteroid complex (tensors of shape [..., 2 * n, ...]).
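
For concreteness, here is a rough sketch of the three layouts holding the same batch of spectra (the variable names are mine):

import torch

# A toy complex spectrogram: batch of 4, n = 3 filters, 10 frames.
native = torch.randn(4, 3, 10, dtype=torch.complex64)          # Torch "native" complex

# Torchaudio layout: a trailing dimension of size 2 holds (real, imag).
torchaudio_style = torch.view_as_real(native)                  # shape [4, 3, 10, 2]

# Asteroid layout: real and imaginary parts stacked along the filter dimension.
asteroid_style = torch.cat([native.real, native.imag], dim=1)  # shape [4, 2 * 3, 10]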

The native complex tensors have a very nice API with things like .angle(), but judging from my personal experience and from the number of bug reports, training with them is still much less stable than with the other two representations (which are FP32-based). Another drawback I realised recently is that they don't allow for FP16, BF16, TF32, etc. training. I am not sure if this is a limitation of the way PyTorch represents these numbers in memory, or simply not yet implemented.

I think there are three ways to deal with the situation:

  1. Accept the shortcomings of native complex numbers, plus maybe advocate for more attention to complex numbers from the PyTorch team. My opinion: the most elegant way, but it probably costs a lot of time, energy and patience until we are there.
  2. Switch our usage of native complex numbers to one of the other representations, and revisit the decision in a few months. My opinion: Probably the easiest and most robust solution.
  3. Write a wrapper around the three representations and let users choose their preferred "backend" representation. My opinion: the best of all worlds. The wrapper wouldn't be too complex (haha). But maybe it's out of scope for Asteroid; I am not sure how that code will grow and how much work it would be to maintain.

Interested to hear your thoughts!

mpariente commented 3 years ago

Indeed we have those 3 representations and I guess it's not very clear to newcomers. Just to clarify, Asteroid's complex format was designed so that complex filterbanks could be used as real filterbanks, using the same interface as nn.Conv1d. We didn't intend to create a self-contained complex format.
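
To illustrate the design (this is a sketch of the idea only, not Asteroid's actual Encoder class): with this layout, a "complex" filterbank is just a real nn.Conv1d whose 2 * n_filters output channels stack real parts first, imaginary parts second.

import torch
from torch import nn

n_filters, kernel_size = 3, 16
fb = nn.Conv1d(1, 2 * n_filters, kernel_size, stride=8, bias=False)

wav = torch.randn(4, 1, 8000)                     # (batch, 1, time)
tf = fb(wav)                                      # (batch, 2 * n_filters, n_frames)
real, imag = tf[:, :n_filters], tf[:, n_filters:]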

A few points:

What would you like to support in this wrapper? Everything that is in complex_nn.py?

Maybe @anjali411 or @boeddeker would like to weigh in on this?

jonashaag commented 3 years ago

The wrapper would have a similar/identical interface to Torch's complex numbers, maybe with a few more operations for performance, and with functions to convert between the representations.

import torch
from torch import Tensor


class AnyComplex:
    def __init__(self, tensor: Tensor):
        self.tensor = tensor

    def real(self) -> Tensor: ...
    def imag(self) -> Tensor: ...
    def abs(self) -> Tensor: ...
    def angle(self) -> Tensor: ...
    def conj(self) -> "AnyComplex": ...
    ...

    # Maybe these could also live outside the class hierarchy.
    def as_torchaudio_complex(self) -> Tensor: ...  # shape [..., 2]
    def as_asteroid_complex(self) -> Tensor: ...    # shape [..., 2 * n, ...]
    def as_torch_complex(self) -> Tensor: ...       # native complex dtype

    @classmethod
    def from_torchaudio_complex(cls, t: Tensor) -> "AnyComplex": ...
    @classmethod
    def from_asteroid_complex(cls, t: Tensor) -> "AnyComplex": ...
    @classmethod
    def from_torch_complex(cls, t: Tensor) -> "AnyComplex": ...


class TorchComplex(AnyComplex):
    def real(self):
        return self.tensor.real
    ...


class AsteroidComplex(AnyComplex):
    def real(self):
        # The first half of the stacked dimension (here assumed to be -2,
        # i.e. the filter dimension) holds the real part.
        n = self.tensor.shape[-2]
        return self.tensor[..., : n // 2, :]
    ...


class TorchaudioComplex(AnyComplex):
    def real(self):
        return self.tensor[..., 0]
    ...
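
A hypothetical round trip with such a wrapper could then look like this (assuming the from_*/as_* methods sketched above are actually implemented):

spec = TorchComplex.from_torch_complex(torch.randn(4, 3, 10, dtype=torch.complex64))
mag = spec.abs()                     # magnitude as a plain real tensor
ta = spec.as_torchaudio_complex()    # switch to the [..., 2] layout
back = TorchaudioComplex.from_torchaudio_complex(ta).as_torch_complex()
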
mpariente commented 3 years ago

Yes, it would be practical to have that. I don't know where it would make sense to have this though.

anjali411 commented 3 years ago

@jonashaag Complex Tensors are in the beta stage right now but a lot of linear algebra and autograd support has been added in the last few months and there's more support incoming.

There’s native support only for the ComplexFloat and ComplexDouble dtypes. We would like to add support for ComplexHalf and ComplexBFloat16 in the future, but it's not something we are actively working on for the next two releases. However, that is subject to change in case there's a growing demand for these dtypes. If you are working with complex data represented as a float or double tensor, it might be a good idea to try the native complex support.
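
For context, a quick check of what this means in practice (the comments hedge what was unsupported at the time of this discussion):

import torch

x = torch.randn(8, dtype=torch.complex64)    # ComplexFloat: supported
y = torch.randn(8, dtype=torch.complex128)   # ComplexDouble: supported
# ComplexHalf exists as torch.complex32, but most operations on it raised
# "not implemented" errors at the time; ComplexBFloat16 has no torch dtype.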

Also, please get involved in the complex torch.nn module discussion here: https://github.com/pytorch/pytorch/issues/46374 if there's something you'd like to see supported or if you have any ideas.

mpariente commented 3 years ago

Thanks for your answer and your efforts on complex support in PyTorch!

I think Jonas already spent time with the native complex support in complex_nn.py and in DCCRN and DCUNet (see asteroid/masknn). He can probably give more detail than me, but training didn't seem stable so far (using nightly).

Thanks for the links to the discussions, and also don't hesitate to ping me if you need feedback on specific things.

jonashaag commented 3 years ago

Thanks for the details @anjali411!

I think BF16 and TF32 support alone would make it worth switching to a real-valued backend, at least for these models. (FP16 does not work well with some of the Asteroid models; it probably needs hand tuning for gradient stability in some places.)
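
For reference, the kind of mixed-precision training a real-valued backend would unlock is the standard torch.cuda.amp recipe; a minimal sketch with a toy model standing in for DCUNet/DCCRN:

import torch
from torch import nn

torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls (PyTorch >= 1.7, Ampere GPUs)

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)).cuda()
opt = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(8, 256, device="cuda")
    y = torch.randn(8, 256, device="cuda")
    opt.zero_grad()
    with torch.cuda.amp.autocast():        # eligible ops run in FP16
        loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()          # loss scaling guards against FP16 underflow
    scaler.step(opt)
    scaler.update()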

At the moment none of these can be used with PyTorch anyway, so there's no point in switching yet, but let's see how well they work once we can test them.

mpariente commented 3 years ago

Should we do this or drop this idea?

jonashaag commented 3 years ago

I don’t know... maybe I will revisit this when/if TF32 or BF16 support has landed in PyTorch and I get access to a GPU that supports them, for faster training. (FP16 training wasn't successful in any of my attempts.)