shawntan / scattermoe

Triton-based implementation of Sparse Mixture of Experts.
Apache License 2.0

torch.autocast errors #17

Open JCBrouwer opened 3 weeks ago

JCBrouwer commented 3 weeks ago

I'm getting a couple of dtype-related errors when using the MLP module inside a torch.autocast block. Here's my simple wrapper around it:

import torch
from torch import Tensor, LongTensor, nn

from scattermoe.mlp import MLP as MoE

class MyMLP(nn.Module):
    def __init__(self, n_experts: int, d_model: int, mlp_ratio: int = 4, d_out: int | None = None) -> None:
        super().__init__()
        self.moe = MoE(
            input_size=d_model,
            hidden_size=d_model * mlp_ratio,
            num_experts=n_experts,
            top_k=1,
            activation=ReLUSquare(),  # ReLUSquare is a custom activation defined elsewhere in my code
        )
        if d_out is not None:
            self.out = nn.Linear(d_model, d_out)
        else:
            self.out = nn.Identity()

    def forward(self, x: Tensor, e: LongTensor) -> Tensor:
        # top_k=1 routing: uniform expert weights, expert indices supplied by the caller
        v = self.moe(
            x, expert_p=torch.ones_like(e, dtype=x.dtype).unsqueeze(1), expert_idxs=e.unsqueeze(1)
        )
        v = self.out(v)
        return v

If I add @torch.autocast(device_type='cuda', dtype=torch.bfloat16) to the forward method, I get the following type mismatch in the linear layer directly after MyMLP:

Traceback (most recent call last):
...
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float
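
For what it's worth, I suspect this first error is just the autocast region ending when the decorated forward returns: the output leaves MyMLP as bfloat16, and the next float32 linear layer sits outside the autocast region, so F.linear sees mismatched dtypes. A stopgap (not the fix I ended up using, and assuming the rest of the model stays in float32) would be to cast back before returning:

    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def forward(self, x: Tensor, e: LongTensor) -> Tensor:
        v = self.moe(
            x, expert_p=torch.ones_like(e, dtype=x.dtype).unsqueeze(1), expert_idxs=e.unsqueeze(1)
        )
        v = self.out(v)
        return v.float()  # re-enter float32 before leaving the autocast region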

If I instead put my whole loss function in an autocast block, I get this issue later in the backward pass:

Traceback (most recent call last):
...
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/scattermoe/parallel_experts.py", line 55, in backward
    d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
RuntimeError: expected scalar type BFloat16 but found Float
JCBrouwer commented 3 weeks ago

I think it's just a matter of adding @custom_fwd and @custom_bwd to the ParallelLinear autograd Function in parallel_experts.py, as explained here: https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops

uldyssian2008 commented 2 weeks ago

Did you solve the problem? I am facing similar issues.

JCBrouwer commented 2 weeks ago

I added the custom_fwd/custom_bwd decorators to the ParallelLinear autograd Function (in scattermoe/parallel_experts.py) like this:

...
from torch.amp import custom_fwd, custom_bwd

class ParallelLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # no cast_inputs, so forward keeps running under the surrounding autocast state
    def forward(
...
    @staticmethod
    @custom_bwd(device_type="cuda")  # backward re-runs under the same autocast state as forward
    def backward(ctx, grad_out):
...

Not sure if this is a generic solution, but it works on my end.
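
For what it's worth, the same docs also describe a cast_inputs option on custom_fwd if you'd rather force the Function to run in float32 regardless of the surrounding autocast dtype. A toy sketch of that pattern (illustrative names only, not scattermoe's actual kernels):

import torch
from torch.amp import custom_fwd, custom_bwd

class MatmulFP32(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        # inside an autocast region, floating-point CUDA inputs are cast to float32
        # and autocast is disabled for the body of forward
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        # backward runs with the same autocast state that forward used
        a, b = ctx.saved_tensors
        return grad_out @ b.t(), a.t() @ grad_out

# usage: out = MatmulFP32.apply(a, b)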