ivannz / cplxmodule

Complex-valued neural networks for pytorch and Variational Dropout for real and complex layers.
MIT License

Feature request: compatibility with einops #27

Closed: pfeatherstone closed this issue 1 year ago

pfeatherstone commented 1 year ago

I would like the following to work:

import torch
from einops import rearrange
from cplxmodule.nn import CplxToCplx
from cplxmodule.nn.modules.casting import ConcatenatedRealToCplx
cplxrearrange = CplxToCplx[rearrange]

x = torch.randn(4, 1024, 2)
y = ConcatenatedRealToCplx()(x)
z = cplxrearrange()(y, 'b t c -> b c t')

I thought it would be possible to compose CplxToCplx and rearrange, but I get long, complicated errors. Am I misusing something, or is CplxToCplx not meant to work with arbitrary functions?

pfeatherstone commented 1 year ago

The following works:

from einops import rearrange
from cplxmodule import Cplx

def cplxrearrange(x, pattern, **axes_lengths):
    return Cplx(rearrange(x.real, pattern, **axes_lengths),
                rearrange(x.imag, pattern, **axes_lengths))
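
A minimal, library-free sketch of why the function above gives the right answer: a rearrangement only moves elements around, so applying it to the real and imaginary parts separately is the same as applying it to the complex values themselves. The names `PairCplx`, `transpose`, and `apply_componentwise` are hypothetical stand-ins made up for this illustration; they are not part of cplxmodule or einops, and nested lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Callable, List

Matrix = List[List[float]]


@dataclass
class PairCplx:
    """Hypothetical stand-in for a complex array stored as two real parts."""
    real: Matrix
    imag: Matrix


def transpose(m: Matrix) -> Matrix:
    """A real-valued rearrangement: swap the two axes of a nested list."""
    return [list(col) for col in zip(*m)]


def apply_componentwise(z: PairCplx, fn: Callable[[Matrix], Matrix]) -> PairCplx:
    """Apply the same real-valued function to the real and imaginary parts."""
    return PairCplx(fn(z.real), fn(z.imag))


# A 2x3 "complex" array split into its real and imaginary parts.
z = PairCplx(
    real=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    imag=[[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]],
)

# Rearranging each part separately yields the 3x2 rearranged complex array.
zt = apply_componentwise(z, transpose)
```

This shortcut is only safe for operations that act identically and independently on both parts, such as reshapes and axis permutations. Something like complex multiplication mixes the real and imaginary parts and cannot be wrapped this way.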

pfeatherstone commented 1 year ago

This works!

from cplxmodule.nn import CplxToCplx
from einops.layers.torch import Rearrange
CplxRearrange = CplxToCplx[Rearrange]