state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

ImportError: /home/ubuntu/.local/lib/python3.10/site-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorESt8optionalIN3c1010ScalarTypeEES5_INS6_6LayoutEES5_INS6_6DeviceEES5_IbES5_INS6_12MemoryFormatEE #576

Closed (turian closed this 1 month ago)

turian commented 1 month ago

I provisioned a Lambda Labs H100 machine and ran:

git clone https://github.com/state-spaces/mamba.git
cd mamba
pip install .

The install seems to work fine, but then when I try to use it:

ubuntu@209-20-156-240:~/mamba$ ipython3
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 7.31.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch
   ...: from mamba_ssm import Mamba
   ...:
   ...: batch, length, dim = 2, 64, 16
   ...: x = torch.randn(batch, length, dim).to("cuda")
   ...: model = Mamba(
   ...:     # This module uses roughly 3 * expand * d_model^2 parameters
   ...:     d_model=dim, # Model dimension d_model
   ...:     d_state=16,  # SSM state expansion factor
   ...:     d_conv=4,    # Local convolution width
   ...:     expand=2,    # Block expansion factor
   ...: ).to("cuda")
   ...: y = model(x)
   ...: assert y.shape == x.shape
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-1-e160b7215ca2> in <module>
      1 import torch
----> 2 from mamba_ssm import Mamba
      3
      4 batch, length, dim = 2, 64, 16
      5 x = torch.randn(batch, length, dim).to("cuda")

~/mamba/mamba_ssm/__init__.py in <module>
      1 __version__ = "2.2.2"
      2
----> 3 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
      4 from mamba_ssm.modules.mamba_simple import Mamba
      5 from mamba_ssm.modules.mamba2 import Mamba2

~/mamba/mamba_ssm/ops/selective_scan_interface.py in <module>
     14     causal_conv1d_cuda = None
     15
---> 16 import selective_scan_cuda
     17
     18

ImportError: /home/ubuntu/.local/lib/python3.10/site-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorESt8optionalIN3c1010ScalarTypeEES5_INS6_6LayoutEES5_INS6_6DeviceEES5_IbES5_INS6_12MemoryFormatEE

karannb commented 1 month ago

This has happened to me before. It might be because the CUDA version on the machine doesn't match the CUDA version of the torch build you have installed. Can you run

nvidia-smi

to see the CUDA version, and then run

pip list | grep torch

and check if the versions match?
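
The same check can also be done from inside Python. This is only a minimal sketch: `describe_torch_build` is a hypothetical helper name, and it reports what torch itself says it was built for rather than what the driver supports (that still comes from nvidia-smi).

```python
import importlib.util


def describe_torch_build():
    """Summarize the installed torch build, or report that torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}

    import torch  # imported lazily so the check also runs without torch

    return {
        "installed": True,
        "torch_version": torch.__version__,          # e.g. "2.3.0+cu121"
        "built_for_cuda": torch.version.cuda,        # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_torch_build().items():
        print(f"{key}: {value}")
```

If `built_for_cuda` is None, the installed torch is a CPU-only build. Otherwise it can be compared against the CUDA version that nvidia-smi reports.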

turian commented 1 month ago

ubuntu@209-20-157-222:~$ nvidia-smi
Sat Sep 28 09:26:12 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 PCIe               On  | 00000000:08:00.0 Off |                    0 |
| N/A   31C    P0              47W / 350W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
ubuntu@209-20-157-222:~$ pip  list | grep torch
torch                     2.0.1

Is there a way to address this?

karannb commented 1 month ago

Okay, great. It seems like you don't have a CUDA-enabled torch installation. On my machine, when I run

pip list | grep torch

I get:

torch                    2.3.0+cu121
torchaudio               2.3.0+cu121
torchvision              0.18.0+cu121
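
The `+cu121` suffix in that output names the CUDA version the wheel was built for. A rough sketch for reading it out of a version string follows; `cuda_tag` is a hypothetical helper, not part of torch or pip. A version with no suffix does not necessarily mean a CPU-only build, since wheels from the default package index may carry no suffix in `pip list`, so `torch.version.cuda` is likely the more reliable check.

```python
import re
from typing import Optional


def cuda_tag(version: str) -> Optional[str]:
    """Return the CUDA version encoded in a '+cuXYZ' suffix, or None."""
    match = re.search(r"\+cu(\d{2,})$", version)
    if match is None:
        return None  # no CUDA suffix (e.g. "2.0.1" or "2.3.0+cpu")
    digits = match.group(1)
    # The last digit is the minor version: "121" -> "12.1", "118" -> "11.8".
    return f"{digits[:-1]}.{digits[-1]}"


if __name__ == "__main__":
    for v in ("2.3.0+cu121", "0.18.0+cu121", "2.0.1"):
        print(v, "->", cuda_tag(v))
```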

For your CUDA version (12.2) I would suggest trying:

pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124

Or check here for other versions. I am not sure whether the 12.1 or the 12.4 build will work with 12.2.

turian commented 1 month ago

Thank you, pip install -U torch torchvision torchaudio fixed it.
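
For anyone hitting the same error: an "undefined symbol" ImportError from a compiled extension generally suggests the extension was built against a different torch than the one loaded at runtime, which would be consistent with the upgrade fixing it. Below is a small sketch for telling that case apart from a plain missing module. `classify_import` and its return labels are hypothetical names chosen for illustration.

```python
import importlib


def classify_import(module_name: str) -> str:
    """Try to import a module and label the outcome."""
    try:
        importlib.import_module(module_name)
    except ImportError as exc:
        # The dynamic loader's message is carried in the exception text.
        if "undefined symbol" in str(exc):
            return "abi-mismatch"  # likely built against another torch
        return "import-error"      # missing module or other import failure
    return "ok"


if __name__ == "__main__":
    print(classify_import("selective_scan_cuda"))
```

When checking a torch extension such as selective_scan_cuda this way, it is probably safest to `import torch` first, because the extension links against torch's shared libraries.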