Previously, if `ffn_hidden_size` was 128 and `top_k` was equal to the number of experts, the output of `nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)` would be something like `tensor(1)` instead of `tensor([1])` -- a zero-dimensional tensor instead of a one-dimensional tensor. This caused an error during concatenation on the next line:
```
File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 148, in sparse_forward_once
    topo = self.topology(x, padded_bins)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 98, in topology
    column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
                                                   ^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 56, in sparse_transpose
    offsets_t = torch.cat([zero, nnz_per_column])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: zero-dimensional tensor (at position 1) cannot be concatenated
```
To address the bug, we simply make `nnz_per_column` a 1D tensor if it is 0D. I added a new set of parameters to the dMoE tests that fails without this change and succeeds with it. I also ran the LLM Foundry `torch_dmoe` vs. `mb_dmoe` tests to verify the correctness of this change.
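The shape of the fix can be sketched with plain PyTorch (the `ensure_1d` helper name is hypothetical, not the actual code in `sparse_transpose`):

```python
import torch

def ensure_1d(t: torch.Tensor) -> torch.Tensor:
    # Promote a zero-dimensional tensor to a one-element 1D tensor;
    # leave tensors that already have a dimension untouched.
    return t.view(1) if t.dim() == 0 else t

# A 0-D cumsum result, as seen when all nonzeros land in a single column.
nnz_per_column = torch.tensor(1)
zero = torch.zeros((1,), dtype=nnz_per_column.dtype)

# torch.cat([zero, nnz_per_column]) would raise:
#   RuntimeError: zero-dimensional tensor (at position 1) cannot be concatenated
offsets_t = torch.cat([zero, ensure_1d(nnz_per_column)])
print(offsets_t)  # tensor([0, 1])
```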
The second change adds clearer error messages for invalid `ffn_hidden_size` values to help external users.
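A minimal sketch of what such a validation might look like (the function name and the assumption that `ffn_hidden_size` must be a positive multiple of the sparse blocking factor are illustrative, not the exact check added):

```python
def validate_ffn_hidden_size(ffn_hidden_size: int, blocking: int = 128) -> None:
    # Hypothetical check: the sparse kernels tile the FFN weights in
    # `blocking`-sized blocks, so ffn_hidden_size must divide evenly.
    if ffn_hidden_size <= 0 or ffn_hidden_size % blocking != 0:
        raise ValueError(
            f'ffn_hidden_size must be a positive multiple of {blocking}, '
            f'got {ffn_hidden_size}.')
```

The point is that the user sees the offending value and the constraint it violated, rather than an opaque shape error deep inside a kernel.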
You can reproduce this error with the small script below as well: