databricks / megablocks


what devices are supported? #155

Open Guodanding opened 6 days ago

Guodanding commented 6 days ago

Hello, I have tried to use megablocks on a V100 with pytorch 2.4.0+cu121, but I get an error saying "cannot support bf16". If I use megablocks in fp32, I get the error "group gemm must use bf16". So I changed my environment to an A6000, and then got this error:

[rank1]:   File "/usr/local/lib/python3.11/dist-packages/megablocks/ops/sort.py", line 34, in forward
[rank1]:     ops.sort(x, end_bit, x_out, iota_out)
[rank1]: RuntimeError: no kernel image is available for execution on the device

So I am wondering, what devices are supported? Only A100/A800 and H100/H800?

mvpatel2000 commented 6 days ago

We have only tested on A100/H100 at this time. With some tweaks it is likely possible to get it working on other hardware.
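
For anyone hitting the same "no kernel image is available" error, a quick way to check what the GPU and the installed PyTorch build actually support (a minimal diagnostic sketch, independent of megablocks):

import torch

# Compute capability of the current GPU: sm_70 = V100, sm_80 = A100,
# sm_86 = A6000, sm_90 = H100.
major, minor = torch.cuda.get_device_capability()
print(f"compute capability: sm_{major}{minor}")

# Architectures the installed PyTorch build ships kernels for.
print("torch arch list:", torch.cuda.get_arch_list())

# bf16 needs Ampere (sm_80) or newer, which is why the V100 attempt failed.
print("bf16 supported:", torch.cuda.is_bf16_supported())

A "no kernel image" error generally means a compiled extension does not include kernels for the device's compute capability.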

Guodanding commented 6 days ago

OK, thanks! I will check it.

Guodanding commented 6 days ago

After I changed my environment to an A100, the issue still exists. I have no idea what to do.

[rank0]:   File "/usr/local/lib/python3.11/dist-packages/megablocks/layers/moe.py", line 468, in forward
[rank0]:     out = self.experts(x, scores, expert_weights, top_experts)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/megablocks/layers/moe.py", line 429, in forward
[rank0]:     x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/megablocks/layers/moe.py", line 262, in parallel_forward_once
[rank0]:     indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
[rank0]:                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/megablocks/layers/moe.py", line 161, in indices_and_bins
[rank0]:     output = ops.sort(top_expert, self.sort_end_bit)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 574, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/megablocks/ops/sort.py", line 34, in forward
[rank0]:     ops.sort(x, end_bit, x_out, iota_out)
[rank0]: RuntimeError: no kernel image is available for execution on the device

This is my environment (nvidia-smi):

root@I1cedf606ff00701a02:~/workspace/DynamicMoE# nvidia-smi
Wed Oct  9 11:16:58 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-PCIE-40GB          On  | 00000000:40:00.0 Off |                    0 |
| N/A   34C    P0              36W / 250W |      4MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

pytorch:

Name: torch
Version: 2.4.0+cu121
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /usr/local/lib/python3.11/dist-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: megablocks, stanford-stk, torchaudio, torchvision

And I have installed the grouped-gemm utils:

pip install git+https://github.com/tgale96/grouped_gemm@main
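
For reference, a quick smoke test of the source-built grouped_gemm (a sketch assuming I have the gmm call right; treat the exact signature and argument layout as an assumption) would be:

import torch
import grouped_gemm as gg

# Two "experts": expert 0 gets 2 tokens, expert 1 gets 6 tokens.
a = torch.randn(8, 128, dtype=torch.bfloat16, device="cuda")       # (tokens, k)
b = torch.randn(2, 128, 256, dtype=torch.bfloat16, device="cuda")  # (experts, k, n)
batch_sizes = torch.tensor([2, 6])  # int64 counts per expert, kept on CPU
out = gg.ops.gmm(a, b, batch_sizes)
print(out.shape)  # expected: torch.Size([8, 256])

If this runs but megablocks' own ops (e.g. ops.sort) still raise "no kernel image", the mismatch is in the megablocks build rather than in grouped_gemm.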

Guodanding commented 6 days ago

After I reinstalled megablocks with

pip install megablocks[all]

instead of installing from source (which is what I had done), the issue disappears.

mvpatel2000 commented 5 days ago

Hm... maybe something is wrong with the grouped gemm wheels...? @tgale96 any ideas here?

Guodanding commented 4 days ago

OK! And what is the difference between the "grouped" impl and the "sparse" impl? Which one gives better throughput?

mvpatel2000 commented 3 days ago

They are different implementations of a grouped matmul. Grouped is a series of iterated kernel calls with standard matmuls, whereas sparse is the strategy described in the MegaBlocks paper: https://arxiv.org/abs/2211.15841
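
Roughly, as a conceptual sketch in plain PyTorch (not what the megablocks kernels actually run): once tokens are sorted so that each expert's tokens are contiguous, a grouped matmul is just one standard matmul per expert over that expert's slice.

import torch

def grouped_matmul_reference(x, w, tokens_per_expert):
    # x: (num_tokens, hidden), sorted so each expert's tokens are contiguous.
    # w: (num_experts, hidden, ffn), one weight matrix per expert.
    # tokens_per_expert: per-expert token counts, summing to num_tokens.
    outs, start = [], 0
    for e, n in enumerate(tokens_per_expert):
        outs.append(x[start:start + n] @ w[e])  # standard matmul per expert
        start += n
    return torch.cat(outs, dim=0)

The grouped path runs these as regular GEMMs; the sparse path instead expresses the whole computation as a single block-sparse matmul, which is the MegaBlocks formulation.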

Guodanding commented 3 days ago

But why is the grouped one recommended instead of the sparse one for Hopper GPUs?

Installing megablocks[gg] enables dMoE computation with grouped GEMM. 
This feature is enabled by setting the mlp_impl argument to grouped. 
This is currently our recommended path for Hopper-generation GPUs.

mvpatel2000 commented 2 days ago

We have observed better performance on H100s with grouped GEMM. This may not be true with the latest versions of Triton.

Guodanding commented 2 days ago

Thanks!

Guodanding commented 2 days ago

There is another issue. I have tried a dMoE forward pass, which passes on my 1xA100 device but stalls (does not crash) on my 2xA6000 device. Both are launched with

torchrun --standalone --nnodes=1 --nproc-per-node=[1 or 2] [my python script]

And this is my Python script:

from megablocks import dMoE
from megablocks import Arguments

import os

import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
print(f"Running example on {rank=} in a world with {world_size=}")
dist.init_process_group("nccl")
torch.cuda.set_device(rank)
num_node_devices = torch.cuda.device_count()
rank_lists = list(range(0, num_node_devices))
ep_groups = dist.new_group(rank_lists)
device = f'cuda:{rank}'

args = Arguments()
args.moe_expert_model_parallelism = True
args.expert_parallel_group = ep_groups
args.device = device
args.bf16 = True
args.fp16 = False
args.mlp_impl = 'grouped'

model = dMoE(args)
model.to(device)

print("build up model success!")

x = torch.rand((2, 4, 1024), dtype=torch.float32).to(torch.bfloat16)
x = x.to(device)

torch.manual_seed(42)
torch.cuda.manual_seed(42)
logits = model(x)

if rank == 0:
    print("didn't crash yay!")

I tried to debug, and found that the process stops here:

# megablocks/layers/moe.py --> class ParallelMLP --> func parallel_forward_once --> line 288

        # This view updates the shape of the tensor from [sl, bs, hs] to
        # [sl * bs, hs] prior to the permutation.
        x = x.view(-1, x.shape[-1])
        print("pass!")
        output = ops.gather(x, indices, bin_ids, bins, self.top_k)
        print("break!")
        assert output is not None
        x = output

The output is:

Running example on rank=1 in a world with world_size=2
Running example on rank=0 in a world with world_size=2
build up model success!
build up model success!
pass!
pass!

Is there anything wrong with ops.gather? Or is more information needed?
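
In the meantime, one way to get more information out of the stall (a general PyTorch/NCCL debugging sketch, not megablocks-specific) is to enable distributed debug logging before the process group is created, or export the same variables in the shell before torchrun:

import os

# Must be set before dist.init_process_group() / the first collective.
os.environ["NCCL_DEBUG"] = "INFO"                 # NCCL communicator and transport logs
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # extra consistency checks on collectives
os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"        # surface c10d warnings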

Guodanding commented 2 days ago

OK, I have tried it on a 2xA100 device and the issue disappears. It seems some tweaks are needed to adapt it to devices other than A100/H100.

mvpatel2000 commented 1 day ago

Hm... I'm not super sure what happens if you are on one GPU -- it might be some error with torch dist initialization that is buried. I unfortunately do not have other GPUs to test on, but if you manage to narrow it down, I'm happy to help fix it.

Guodanding commented 1 day ago

OK, I will try it.