i404788 / s5-pytorch

Pytorch implementation of Simplified Structured State-Spaces for Sequence Modeling (S5)
Mozilla Public License 2.0

RuntimeError with complex parameter type, Adam and Weight Decay #2

Open Snagnar opened 1 year ago

Snagnar commented 1 year ago

Hi,

I was trying to use S5Block with an Adam optimizer with weight decay. However, I ran into a strange bug where the sizes of the parameters and gradients mismatch. The error only occurs with CUDA tensors/models and only when weight_decay is enabled. Below is a minimal script that reproduces the bug:

from s5 import S5Block
import torch

# Input batch and S5 block, both moved to the GPU (the bug only triggers on CUDA)
x = torch.randn(16, 64, 256).cuda()
a = S5Block(256, 128, False).cuda()
a.train()

# h = torch.optim.Adam(a.parameters(), lr=0.001)  # this works
h = torch.optim.Adam(a.parameters(), lr=0.001, weight_decay=0.0001)  # this doesn't work

out = a(x)
out.sum().backward()
h.step()  # RuntimeError raised here

After a lot of digging I found the cause: the handling of complex parameters in _multi_tensor_adam is faulty in the latest PyTorch release, 2.0.1. Specifically, in line 442 of torch/optim/adam.py, the wrong variable is used when computing the weight decay.
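To make the failure mode concrete: viewing a complex tensor as real appends a trailing dimension of size 2, so a parameter left in its complex shape no longer lines up with a gradient that has been viewed as real. A minimal sketch (illustrative only, not the actual optimizer code):

import torch

p = torch.randn(4, dtype=torch.complex64)  # complex parameter, shape (4,)
g = torch.view_as_real(p.clone())          # real view of a complex tensor, shape (4, 2)
print(p.shape, g.shape)                    # torch.Size([4]) torch.Size([4, 2])
# If the weight-decay step adds the unviewed parameter to the viewed gradient,
# the shapes (4,) and (4, 2) do not match and the step raises a RuntimeError.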

However, this seems to have been fixed on the main branch since May 9 with this commit, so a newer PyTorch version should work. In 2.0.1 it remains broken.
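Until the fix is released, a possible workaround (assuming the bug is confined to the multi-tensor code path) is to pass foreach=False, which makes Adam fall back to its single-tensor implementation:

# Workaround sketch: foreach=False selects the single-tensor Adam implementation,
# bypassing the faulty _multi_tensor_adam, so weight decay can stay enabled.
h = torch.optim.Adam(a.parameters(), lr=0.001, weight_decay=0.0001, foreach=False)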

Just posting this here in case anyone else is having this issue.