databricks / megablocks

Apache License 2.0

Illegal memory access on non-0 CUDA devices from `histogram` #117

Open phillip-kravtsov opened 5 days ago


When the input tensor is not on device 0, `histogram` causes an illegal memory access. As a result, `indices_and_bins` cannot be computed correctly for a model and inputs that are not on device 0.

Reproduction:

import torch
import megablocks

idx = 1  # any index other than 0 fails; idx = 0 works
device = f'cuda:{idx}'

test_tensor = torch.tensor([0], dtype=torch.int64, device=device)
result = megablocks.ops.histogram(test_tensor, 1).cpu()

When run with `CUDA_LAUNCH_BLOCKING=1`, we get:

Traceback (most recent call last):
  File "/home/ubuntu/test_mb.py", line 8, in <module>
    result = megablocks.ops.histogram(test_tensor, 1).cpu()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/.local/lib/python3.10/site-packages/megablocks/ops/histogram.py", line 17, in forward
    return ops.histogram(x, max_val)
RuntimeError: an illegal memory access was encountered

When `idx` is set to 0, on the other hand, the correct values are computed. I'm quite confused as to how this is possible.
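One possible workaround rests on an unconfirmed assumption: the extension might launch its kernel (or allocate scratch memory) on the *current* CUDA device rather than on `x.device`. If so, making the tensor's device current before the call should avoid the mismatch. The helper name `run_on_tensor_device` is hypothetical. `torch.cuda.device` is PyTorch's context manager for temporarily switching the current device, and calling `torch.cuda.set_device(idx)` once at startup would be a blunter equivalent.

```python
import torch

def run_on_tensor_device(fn, x, *args):
    """Call fn(x, *args) with x's CUDA device temporarily made current.

    Hypothetical helper. It assumes (unconfirmed) that the extension uses
    the current CUDA device instead of x.device when launching work.
    """
    if x.is_cuda:
        # Switch the current device to x's device for the duration of the call.
        with torch.cuda.device(x.device):
            return fn(x, *args)
    # CPU tensors: nothing to switch, call directly.
    return fn(x, *args)
```

Applied to the reproduction above, this would read `result = run_on_tensor_device(megablocks.ops.histogram, test_tensor, 1).cpu()`. If that call succeeds on `cuda:1` while the plain call fails, a missing device guard in the extension becomes the likely explanation.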

I'm on megablocks 0.5.1 with:

numpy==2.0.0
torch==2.3.1
torchaudio==2.3.1
torchvision==0.18.1
triton==2.3.1

CUDA 12.1; reproduced on two A100s.
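To confirm which devices produce correct results, each visible GPU's output could be compared against a CPU reference. This is a minimal sketch. It assumes that `histogram(x, max_val)` returns the count of each integer value in `[0, max_val)`, which is not stated in this issue. The function names `reference_histogram` and `check_histogram_on_devices` are hypothetical. An illegal memory access usually leaves the CUDA context unusable, so on an affected device you should expect a `RuntimeError` rather than a `False` entry.

```python
import torch

def reference_histogram(x, max_val):
    # Assumed semantics: count of each integer value in [0, max_val).
    return torch.bincount(x.cpu(), minlength=max_val)[:max_val]

def check_histogram_on_devices(hist_fn, max_val=1, values=(0,)):
    """Return {device_name: output_matches_reference} for every visible GPU."""
    x_cpu = torch.tensor(list(values), dtype=torch.int64)
    expected = reference_histogram(x_cpu, max_val)
    results = {}
    for idx in range(torch.cuda.device_count()):
        x = x_cpu.to(f"cuda:{idx}")
        out = hist_fn(x, max_val).cpu()
        # Cast to the reference dtype so differing integer widths still compare equal.
        results[f"cuda:{idx}"] = torch.equal(out.to(expected.dtype), expected)
    return results
```

For example, `check_histogram_on_devices(megablocks.ops.histogram)` would run the same single-element case as the reproduction on every GPU in turn.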