Hey!
Currently, torch.fft.fft with dtype torch.float16 and complex tensors is not supported.
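For context, here is a minimal sketch of where the limitation shows up (shapes and tensors are made up for illustration, not taken from FNet): under autocast, CUDA matmuls produce float16 outputs, which torch.fft.fft cannot consume directly, while casting back to float32 works.

```python
import torch

x = torch.randn(8, 128, 64, device="cuda")
w = torch.randn(64, 64, device="cuda")

with torch.cuda.amp.autocast():
    y = x @ w                                    # y is float16 under autocast
    # out = torch.fft.fft(y, dim=-1)             # fails: half precision not supported here
    out = torch.fft.fft(y.float(), dim=-1).real  # ok: the FFT runs in float32
```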
As a workaround, I've disabled autocast for FourierFFTLayer:
```python
import torch
import torch.nn as nn


class FourierFFTLayer(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, hidden_states):
        # Run both FFTs in float32; autocast is disabled for this forward pass.
        return torch.fft.fft(
            torch.fft.fft(hidden_states.float(), dim=-1), dim=-2
        ).real
```
We convert back to float32 (using .float()) here in case autocast is enabled in the parent scope, where hidden_states might have a different dtype.
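An equivalent formulation (sketched here for illustration, not taken from the original code) uses the context-manager form of autocast inside forward, which keeps the cast and the FFT calls together:

```python
import torch
import torch.nn as nn


class FourierFFTLayer(nn.Module):
    def forward(self, hidden_states):
        # Locally disable autocast so the FFTs always see float32 inputs.
        with torch.cuda.amp.autocast(enabled=False):
            return torch.fft.fft(
                torch.fft.fft(hidden_states.float(), dim=-1), dim=-2
            ).real
```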
Example
This snippet now works as expected:
```python
import torch
from fnet import FNet

m = FNet(...).cuda().train()

N, T = 128, 128
inp = (
    torch.randint(0, 256, (N, T,)).cuda(),
    torch.randint(0, 4, (N, T,)).cuda(),
)

# forward pass
with torch.cuda.amp.autocast():
    h, _ = m(*inp)

# backward pass
h.mean().backward()
```
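For a complete mixed-precision training step, the usual pattern also wraps the backward pass with torch.cuda.amp.GradScaler (a sketch assuming some optimizer opt, reusing m and inp from above; the choice of Adam is arbitrary):

```python
import torch

opt = torch.optim.Adam(m.parameters())  # hypothetical optimizer choice
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    h, _ = m(*inp)
    loss = h.mean()

scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
scaler.step(opt)               # unscales gradients, then calls opt.step()
scaler.update()
opt.zero_grad()
```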
Gotchas
Requires torch==1.6.0 or newer.