HazyResearch / flash-fft-conv

FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores
Apache License 2.0

Is FlashDepthWiseConv1d incompatible with torch.autocast? #3

Closed: manuel-tran closed this issue 11 months ago

manuel-tran commented 11 months ago

Dear Hazy Research Team,

Thanks for releasing flash-fft-conv. I was eagerly awaiting the announcement after reading the Hyena paper. While integrating FlashDepthWiseConv1d, I noticed that it only works when both inputs and weights are half-precision (RuntimeError: u must be float16 or bfloat16 and RuntimeError: weight must be float16 or bfloat16).

Although converting both to FP16 is not an issue when training without torch.autocast, it is a problem in mixed-precision training because GradScaler cannot unscale FP16 gradients (ValueError: Attempting to unscale FP16 gradients). Is there a way to use FlashDepthWiseConv1d in a Hyena model with mixed-precision training?

Here is a minimal example. Thank you very much!

import torch
import torch.nn as nn

from flashfftconv import FlashDepthWiseConv1d

# Reference depthwise conv with fp16 parameters.
conv1d_torch = nn.Conv1d(
    in_channels=512*3,
    out_channels=512*3,
    kernel_size=3,
    groups=512*3,
    padding=1,
    dtype=torch.float16
).cuda()

# FlashDepthWiseConv1d initialized from the same fp16 weights.
flash_conv1d = FlashDepthWiseConv1d(
    channels=512*3,
    kernel_size=3,
    padding=1,
    weights=conv1d_torch.weight,
    bias=conv1d_torch.bias,
    dtype=torch.float16
).cuda()

x = torch.rand(1, 1536, 2048, device='cuda', requires_grad=True)
y = torch.rand(1, 1536, 2048, device='cuda')  # regression target

# Both layers require fp16/bf16 inputs.
out_torch = conv1d_torch(x.half())
out_flash = flash_conv1d(x.half())

criterion = nn.MSELoss().cuda()
optimizer = torch.optim.AdamW(flash_conv1d.parameters())
scaler = torch.cuda.amp.GradScaler()

with torch.autocast(device_type='cuda', dtype=torch.float16):
    optimizer.zero_grad()
    logits = flash_conv1d(x.half())
    loss = criterion(logits, y)

scaler.scale(loss).backward()
scaler.step(optimizer)  # raises ValueError: Attempting to unscale FP16 gradients
scaler.update()

DanFu09 commented 11 months ago

Hi, thanks for your interest, and great question!

We currently only support fp16/bf16 inputs and weights. We'll be bringing in fp32 support everywhere soon (including the autocast version, with fp32 weights and fp16/bf16 inputs). Will update this issue when that's in.
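To sketch what we're aiming for (rough, untested pseudocode; the exact constructor arguments, e.g. passing dtype=torch.float32, may end up looking different): the autocast path would keep the conv parameters in fp32 and take fp16/bf16 inputs, so GradScaler unscales fp32 gradients as usual.

import torch
import torch.nn as nn
from flashfftconv import FlashDepthWiseConv1d

# fp32 reference conv; its fp32 parameters seed FlashDepthWiseConv1d.
conv_ref = nn.Conv1d(512*3, 512*3, kernel_size=3, groups=512*3, padding=1).cuda()

# Assumed: fp32 parameters once fp32 support lands.
flash_conv1d = FlashDepthWiseConv1d(
    channels=512*3,
    kernel_size=3,
    padding=1,
    weights=conv_ref.weight,
    bias=conv_ref.bias,
    dtype=torch.float32,
).cuda()

x = torch.rand(1, 1536, 2048, device='cuda')
y = torch.rand(1, 1536, 2048, device='cuda')

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(flash_conv1d.parameters())
scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    out = flash_conv1d(x.half())   # fp16 inputs, fp32 parameters
    loss = criterion(out, y)

scaler.scale(loss).backward()      # gradients land on the fp32 parameters
scaler.step(optimizer)             # unscaling fp32 gradients works
scaler.update()

The key change from the snippet above is that flash_conv1d's parameters stay in fp32, which is what GradScaler expects when it unscales.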

DanFu09 commented 11 months ago

Hi, we've just pushed a commit (afceac4) that should fix this. Can you give it a try and see if it works for your pipelines now?

manuel-tran commented 11 months ago

Hi, thanks for the update. I tried the new feature and it works!