JCBrouwer opened this issue 3 weeks ago
I think it's just a question of adding @custom_fwd and @custom_bwd to the ParallelExperts autograd Function as explained here: https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops
Did you solve the problem? I am facing similar issues.
I added the custom_fwd/custom_bwd decorators to the ParallelExperts class like this:
```python
...
from torch.amp import custom_fwd, custom_bwd

class ParallelLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(
        ...

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        ...
```
Not sure if this is a generic solution, but it works on my end.
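For reference, here's a minimal, self-contained sketch of the same decorator pattern on a toy autograd Function. `ToyMatmul`, the shapes, and the final print are made up for illustration and are not the library's actual ParallelLinear kernel; the sketch just shows how custom_fwd/custom_bwd behave under autocast:

```python
import torch
from torch.amp import custom_fwd, custom_bwd


class ToyMatmul(torch.autograd.Function):
    """Toy stand-in (not ParallelLinear): a plain matmul wrapped in an
    autograd.Function, just to illustrate the decorator pattern."""

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        # Under autocast, this matmul runs in the autocast dtype (e.g. bf16).
        return x @ weight

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        # custom_bwd re-enables the same autocast state that forward ran under,
        # so mixing the bf16 grad_out with the fp32 saved tensors does not
        # raise a dtype-mismatch error here.
        return grad_out @ weight.t(), x.t() @ grad_out


if torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda", requires_grad=True)
    w = torch.randn(16, 32, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = ToyMatmul.apply(x, w)
    out.float().sum().backward()
    # Activations come out in bf16 and the backward completes without dtype errors.
    print(out.dtype, x.grad.dtype, w.grad.dtype)
```

The key point is that custom_bwd makes the backward run under the same autocast state as the forward, which is exactly where the backward-pass dtype mismatch was coming from.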
I'm getting a couple of dtype-related errors when using the MLP module in a torch.autocast block. Here's my simple wrapper of the MLP module:
If I add `@torch.autocast(device_type='cuda', dtype=torch.bfloat16)` to the forward method, I get the following type mismatch on the linear layer directly after MyMLP:

If I put my whole loss function in an autocast block, I get this issue later in the backward pass: