databricks / megablocks


Fix for `ffn_hidden_size` of 128, and better error message for incompatible ffn sizes. #108

Closed snarayan21 closed 1 month ago

snarayan21 commented 1 month ago

Previously, if `ffn_hidden_size` was 128 and top-k was equal to the number of experts, the output of `nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)` would be something like `torch.tensor(1)` instead of `torch.tensor([1])` -- a zero-dimensional tensor instead of a one-dimensional tensor. This caused an error during the concatenation on the next line:

  File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 148, in sparse_forward_once
    topo = self.topology(x, padded_bins)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 98, in topology
    column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
                                                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 56, in sparse_transpose
    offsets_t = torch.cat([zero, nnz_per_column])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: zero-dimensional tensor (at position 1) cannot be concatenated

To address the bug, we simply promote `nnz_per_column` to a 1-D tensor if it is 0-D. I added a new set of parameters to the `dmoe` tests that fails without this change and succeeds with it. I also ran the LLM Foundry `torch_dmoe` vs. `mb_dmoe` tests to verify the correctness of this change.
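For illustration, here is a minimal standalone sketch of that promotion; the variable names follow the traceback above, and the exact patch in `sparse_transpose` may differ:

```python
import torch

# A 0-D cumsum result, analogous to the degenerate case described above.
nnz_per_column = torch.tensor(1)

# Promote a 0-D tensor to shape (1,) so it can be concatenated.
if nnz_per_column.dim() == 0:
    nnz_per_column = nnz_per_column.unsqueeze(0)

zero = torch.zeros((1,), dtype=nnz_per_column.dtype)
offsets_t = torch.cat([zero, nnz_per_column])  # works; previously raised RuntimeError
print(offsets_t)  # tensor([0, 1])
```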

The second change adds better error messages for invalid `ffn_hidden_size` values, to help external users.
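As a rough illustration of the kind of check this implies (hypothetical; the actual condition and message added in this PR may differ), an argument validation could look like:

```python
def validate_ffn_hidden_size(ffn_hidden_size: int, blocking: int = 128) -> None:
    """Hypothetical validation helper; the real check in MegaBlocks may differ."""
    if ffn_hidden_size % blocking != 0:
        raise ValueError(
            f'ffn_hidden_size ({ffn_hidden_size}) must be a multiple of the '
            f'sparse block size ({blocking}).'
        )
```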

You can also reproduce the error with the small script below:

```python
import pdb

import torch

from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import ParallelDroplessMLP

# ffn_hidden_size of 128 triggers the degenerate 0-D cumsum case.
args = Arguments(hidden_size=256, ffn_hidden_size=128)
pdmlp = ParallelDroplessMLP(args)

x = torch.randn((128, 128)).cuda().to(torch.bfloat16)
expert_weights = torch.randn((128, 1)).cuda().to(torch.bfloat16)
top_experts = torch.zeros((128, 1)).cuda().to(torch.int32)

pdb.set_trace()
topo = pdmlp.sparse_forward_once(x, expert_weights, top_experts)
```
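On a GPU machine, stepping past the breakpoint should hit the `RuntimeError` from `torch.cat` in `sparse_transpose` shown in the traceback above.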