erksch / fnet-pytorch

Unofficial PyTorch implementation of Google's FNet: Mixing Tokens with Fourier Transforms. With checkpoints.

Disable amp autocast in FourierFFTLayer #10

Closed · iceychris closed this 3 years ago

iceychris commented 3 years ago

Hey!

Currently, torch.fft.fft does not support torch.float16 inputs (or the corresponding half-precision complex tensors), so it fails when called under AMP autocast.
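For reference, a minimal reproduction of the failure, assuming a CUDA device and a PyTorch build without half-precision FFT support (the exact error message may differ):

import torch

x = torch.randn(4, 16, 16, dtype=torch.float16, device="cuda")
try:
    # fails on builds where torch.fft does not accept float16 inputs
    torch.fft.fft(x, dim=-1)
except RuntimeError as e:
    print("fft on float16 failed:", e)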

As a workaround, I've disabled autocast for FourierFFTLayer:

import torch
import torch.nn as nn


class FourierFFTLayer(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, hidden_states):
        # cast to float32 first: under autocast the activations may arrive as float16,
        # which torch.fft.fft cannot handle; keep only the real part of the 2D FFT
        return torch.fft.fft(torch.fft.fft(hidden_states.float(), dim=-1), dim=-2).real

The .float() call converts hidden_states back to float32 because autocast may be enabled in the parent scope, in which case hidden_states can arrive as float16.
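To see the effect, one can compare the dtypes inside and outside the layer under autocast; a small sketch using the FourierFFTLayer defined above (the shapes are illustrative):

import torch

layer = FourierFFTLayer().cuda()

with torch.cuda.amp.autocast():
    # simulate the float16 activations the layer may receive from the parent scope
    hidden_states = torch.randn(2, 8, 16, device="cuda").half()
    out = layer(hidden_states)

print(hidden_states.dtype)  # torch.float16
print(out.dtype)            # torch.float32, since autocast is disabled inside the layer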

Example

This snippet now works as expected:

import torch
from fnet import FNet

m = FNet(...).cuda().train()
N, T = 128, 128  # batch size and sequence length
inp = (torch.randint(0, 256, (N, T,)).cuda(), torch.randint(0, 4, (N, T,)).cuda())

# forward pass
with torch.cuda.amp.autocast():
    h, _ = m(*inp)

# backward pass
h.mean().backward()
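
For a real AMP training loop one would usually also scale the loss with a GradScaler before the backward pass; a minimal sketch, assuming an optimizer opt has been created for the model m above:

scaler = torch.cuda.amp.GradScaler()

opt.zero_grad()
with torch.cuda.amp.autocast():
    h, _ = m(*inp)
    loss = h.mean()

scaler.scale(loss).backward()  # scale the loss to avoid float16 gradient underflow
scaler.step(opt)               # unscales gradients, then steps the optimizer
scaler.update()                # adjusts the scale factor for the next iteration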

Gotchas