mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

BUG: AssertionError: Mamba is not installed. Please install it using `pip install mamba-ssm`. #192

Open matbee-eth opened 1 month ago

matbee-eth commented 1 month ago

Python -VV

(codestral) ➜  dev python -VV
Python 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]

Pip Freeze

(codestral) ➜  dev pip freeze
absl-py==2.1.0
addict==2.4.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
apex @ file:///home/acidhax/dev/VisionLLM/VisionLLMv2/apex
attrs==23.2.0
beautifulsoup4==4.12.3
black==24.4.2
Brotli @ file:///croot/brotli-split_1714483155106/work
causal-conv1d==1.4.0
certifi @ file:///croot/certifi_1720453481653/work/certifi
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
cloudpickle==3.0.0
crowdposetools==2.0
DCNv3==1.0
detectron2 @ git+https://github.com/facebookresearch/detectron2.git@e8806d607403cf0f2634d4c5ac464109fdc7d4af
docstring_parser==0.16
einops==0.8.0
filelock @ file:///croot/filelock_1700591183607/work
fire==0.6.0
fsspec==2024.6.1
fvcore==0.1.5.post20221221
gdown==5.2.0
gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
grpcio==1.64.1
huggingface-hub==0.23.5
hydra-core==1.3.2
idna @ file:///croot/idna_1714398848350/work
importlib_metadata==7.1.0
iopath==0.1.9
Jinja2 @ file:///croot/jinja2_1716993405101/work
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
mamba-ssm==2.2.2
MarkupSafe @ file:///croot/markupsafe_1704205993651/work
mistral_common==1.2.1
mistral_inference==1.2.0
mkl-fft @ file:///croot/mkl_fft_1695058164594/work
mkl-random @ file:///croot/mkl_random_1695059800811/work
mkl-service==2.4.0
mpmath @ file:///croot/mpmath_1690848262763/work
MultiScaleDeformableAttention==1.0
mypy-extensions==1.0.0
networkx @ file:///croot/networkx_1720002482208/work
ninja==1.11.1.1
numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
packaging==24.1
pathspec==0.12.1
pillow @ file:///croot/pillow_1721059439630/work
platformdirs==4.2.2
portalocker==2.8.2
pycocotools @ git+https://github.com/youtubevos/cocoapi.git@f24b5f58594adfe4f4c015bf49dbc819cc3be98f#subdirectory=PythonAPI
pydantic==2.6.1
pydantic_core==2.16.2
PySocks==1.7.1
PyYAML @ file:///croot/pyyaml_1698096049011/work
referencing==0.35.1
regex==2024.5.15
requests @ file:///croot/requests_1716902831423/work
rpds-py==0.19.0
safetensors==0.4.3
sentencepiece==0.1.99
simple_parsing==0.1.5
six==1.16.0
soupsieve==2.5
sympy @ file:///croot/sympy_1701397643339/work
tensorboard-data-server==0.7.2
termcolor==2.4.0
tokenizers==0.19.1
tomli==2.0.1
torch==2.3.1
torchaudio==2.3.1
torchvision==0.18.1
tqdm==4.66.4
transformers==4.42.4
triton==2.3.1
typing_extensions @ file:///croot/typing_extensions_1715268824938/work
urllib3 @ file:///croot/urllib3_1718912636303/work
Werkzeug==3.0.3
xformers==0.0.27
yacs==0.1.8
yapf==0.40.2
zipp==3.19.2

Reproduction Steps

(codestral) ➜  dev pip install mamba-ssm
Requirement already satisfied: mamba-ssm in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (2.2.2)
Requirement already satisfied: torch in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from mamba-ssm) (2.3.1)
Requirement already satisfied: packaging in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from mamba-ssm) (24.1)
Requirement already satisfied: ninja in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from mamba-ssm) (1.11.1.1)
Requirement already satisfied: einops in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from mamba-ssm) (0.8.0)
Requirement already satisfied: triton in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from mamba-ssm) (2.3.1)
Requirement already satisfied: transformers in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from mamba-ssm) (4.42.4)
Requirement already satisfied: filelock in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from torch->mamba-ssm) (3.13.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from torch->mamba-ssm) (4.11.0)
Requirement already satisfied: sympy in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from torch->mamba-ssm) (1.12)
Requirement already satisfied: networkx in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from torch->mamba-ssm) (3.3)
Requirement already satisfied: jinja2 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from torch->mamba-ssm) (3.1.4)
Requirement already satisfied: fsspec in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from torch->mamba-ssm) (2024.6.1)
Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (0.23.5)
Requirement already satisfied: numpy<2.0,>=1.17 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (1.26.4)
Requirement already satisfied: pyyaml>=5.1 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (6.0.1)
Requirement already satisfied: regex!=2019.12.17 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (2024.5.15)
Requirement already satisfied: requests in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (2.32.2)
Requirement already satisfied: safetensors>=0.4.1 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (0.4.3)
Requirement already satisfied: tokenizers<0.20,>=0.19 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (0.19.1)
Requirement already satisfied: tqdm>=4.27 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from transformers->mamba-ssm) (4.66.4)
Requirement already satisfied: MarkupSafe>=2.0 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from jinja2->torch->mamba-ssm) (2.1.3)
Requirement already satisfied: charset-normalizer<4,>=2 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from requests->transformers->mamba-ssm) (2.0.4)
Requirement already satisfied: idna<4,>=2.5 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from requests->transformers->mamba-ssm) (3.7)
Requirement already satisfied: urllib3<3,>=1.21.1 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from requests->transformers->mamba-ssm) (2.2.2)
Requirement already satisfied: certifi>=2017.4.17 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from requests->transformers->mamba-ssm) (2024.7.4)
Requirement already satisfied: mpmath>=0.19 in /mnt/environments/miniconda3/envs/codestral/lib/python3.10/site-packages (from sympy->torch->mamba-ssm) (1.3.0)

(codestral) ➜  dev mistral-chat $HOME/7B_MAMBA_CODE --instruct --max_tokens 256

Traceback (most recent call last):
  File "/home/acidhax/miniconda3/envs/codestral/bin/mistral-chat", line 8, in <module>
    sys.exit(mistral_chat())
  File "/home/acidhax/miniconda3/envs/codestral/lib/python3.10/site-packages/mistral_inference/main.py", line 201, in mistral_chat
    fire.Fire(interactive)
  File "/home/acidhax/miniconda3/envs/codestral/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/acidhax/miniconda3/envs/codestral/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/acidhax/miniconda3/envs/codestral/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/acidhax/miniconda3/envs/codestral/lib/python3.10/site-packages/mistral_inference/main.py", line 81, in interactive
    model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
  File "/home/acidhax/miniconda3/envs/codestral/lib/python3.10/site-packages/mistral_inference/mamba.py", line 75, in from_folder
    model = Mamba(model_args)
  File "/home/acidhax/miniconda3/envs/codestral/lib/python3.10/site-packages/mistral_inference/mamba.py", line 27, in __init__
    assert _is_mamba_installed, "Mamba is not installed. Please install it using `pip install mamba-ssm`."
AssertionError: Mamba is not installed. Please install it using `pip install mamba-ssm`.
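
The assertion itself comes from a guarded import that swallows the real failure, so running the same import directly surfaces the underlying error. A minimal check, using the module paths from the mamba_ssm import that mistral_inference/mamba.py guards (quoted later in this thread):

# Run the import that mistral_inference/mamba.py wraps in try/except, so the
# real ImportError (e.g. an undefined CUDA symbol) is printed instead of the
# generic "Mamba is not installed" assertion.
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

print("mamba_ssm imports OK:", MambaConfig.__name__, MambaLMHeadModel.__name__)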

Expected Behavior

mistral-chat loads the model and starts the interactive session, since mamba-ssm is already installed.

Additional Context

No response

Suggested Solutions

No response

randoentity commented 1 month ago

Same here. It's definitely something with the conda env; probably a missing CUDA/nvcc library.

ImportError: ...envs/mamba/lib/python3.12/site-packages/selective_scan_cuda.cpython-312-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

Possible solutions: https://github.com/state-spaces/mamba/issues/169#issuecomment-2099147220

Okay, so it looks like we need torch < 2.3.0 (i.e. 2.2.2), but xformers 0.0.27 requires torch >= 2.3.0. For me, mistral-inference doesn't work with torch 2.3.0 (the mamba dependency problem linked above), yet it also doesn't work with the xformers binaries < 0.0.27. Running mistral-inference < 1.2.0 gives other errors instead (RuntimeError: Couldn't instantiate class <class 'mistral_inference.model.ModelArgs'> using init args dict_keys(['dim', 'n_layers', 'vocab_size']): ModelArgs.__init__() missing 5 required positional arguments: 'head_dim', 'hidden_dim', 'n_heads', 'n_kv_heads', and 'norm_eps'). I'm now trying to build xformers from source and shot myself in the foot by not setting MAX_JOBS for ninja: on a 12-thread CPU with 128 GB of RAM, 12 parallel jobs saturate all 128 GB even with nothing else running, not even a DE.
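
A quick way to keep track of which wheels are actually active while juggling these versions (a minimal sketch using importlib.metadata; package names as they appear in the pip freeze above):

# Print the installed versions of the packages involved in the
# torch / xformers / mamba-ssm compatibility puzzle described above.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("torch", "xformers", "mamba-ssm", "causal-conv1d", "mistral_inference"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")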

Now at: AttributeError: 'Mamba2' object has no attribute 'dconv'. Did you mean: 'd_conv'?

https://github.com/state-spaces/mamba/issues/452

patrickvonplaten commented 1 month ago

Hmm, not sure what's going on. Here's a Google Colab showing how to install: https://colab.research.google.com/drive/1aHH4PW4eBU_R4R8pQ9BuYeOeMTiA98NF?usp=sharing

Does that work for you?

matbee-eth commented 1 month ago

I got it to work by using conda's CUDA 11.8 and Python 3.9.

Seems like cuda 12.1 has problems with this project

patrickvonplaten commented 1 month ago

Hmm it also works for CUDA 12.1 for me. It's most likely linked to the mamba_ssm install - maybe check here: https://github.com/state-spaces/mamba?tab=readme-ov-file#installation

randoentity commented 1 month ago

Working now for me too. Default cuda 12.1. Python 3.10.14 was giving me problems. This works with a fresh conda env:

conda install python==3.10.13 pip
pip install "mistral_inference>=1" mamba-ssm causal-conv1d

chvipdata commented 1 month ago

> Working now for me too. Default cuda 12.1. Python 3.10.14 was giving me problems. This works with a fresh conda env:
>
> conda install python==3.10.13 pip
> pip install "mistral_inference>=1" mamba-ssm causal-conv1d

It does not work for me.

xNul commented 1 month ago

> Working now for me too. Default cuda 12.1. Python 3.10.14 was giving me problems. This works with a fresh conda env:
>
> conda install python==3.10.13 pip
> pip install "mistral_inference>=1" mamba-ssm causal-conv1d

Thank you! This was my issue. There seems to be some problem with Python 3.10.14, or changing the version to 3.10.13 triggered a fix.

I'm using WSL Ubuntu 22.04 and installed CUDA with these commands:

wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get -y install cuda-toolkit-12-5

If you have an existing cudatoolkit install, you should uninstall it before trying to install this cudatoolkit package.

You may have to add nvcc to your PATH afterwards (export PATH=$PATH:/usr/local/cuda/bin). I did in my case, but I installed CUDA a while before actually fixing the issue, so it may not be necessary.
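
A small sanity check that the nvcc on PATH lines up with the CUDA build torch ships with (a minimal sketch; it assumes torch is installed and importable):

# Print the CUDA version torch was built against and the nvcc found on PATH.
# A mismatch between the two is a common source of build/import problems for
# compiled extensions like mamba-ssm and causal-conv1d.
import shutil
import subprocess

import torch

print("torch", torch.__version__, "built with CUDA", torch.version.cuda)

nvcc = shutil.which("nvcc")
if nvcc is None:
    print("nvcc not found on PATH")
else:
    print(subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout.strip())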

My conda env install commands:

rm -rf ~/.cache/conda ~/.cache/pip ~/.cache/huggingface # fixes the selective_scan_cuda ImportError for me
conda create -n tempenv python=3.10.13
conda activate tempenv
pip install "mistral-inference>=1.2"
pip install packaging # cannot install this package inline with the following pip install command
pip install mamba-ssm causal-conv1d transformers

With those commands I end up with the same package versions as the Colab notebook, and running mistral-chat with:

mistral-chat $HOME/mistral_models/mamba-codestral-7B-v0.1 --instruct --max_tokens 256

works for me. I downloaded Codestral Mamba using the Hugging Face snapshot download Python code.
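
For completeness, a minimal sketch of that snapshot download (the repo id and file list here are assumptions, so check them against the model card; the local path matches the mistral-chat command above):

# Download the Mamba-Codestral weights into the folder mistral-chat is pointed at.
from pathlib import Path
from huggingface_hub import snapshot_download

local_dir = Path.home() / "mistral_models" / "mamba-codestral-7B-v0.1"
local_dir.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id="mistralai/Mamba-Codestral-7B-v0.1",  # assumed repo id
    allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],  # assumed file names
    local_dir=local_dir,
)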

updateforever commented 1 week ago

import selective_scan_cuda
ImportError: $HOME/lib/python3.10/site-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorEN3c108optionalINS5_10ScalarTypeEEENS6_INS5_6LayoutEEENS6_INS5_6DeviceEEENS6_IbEENS6_INS5_12MemoryFormatEEE

I think it might be a version mismatch issue. I modified mistral_inference/mamba.py to import the mamba_ssm modules directly instead of going through the try/except guard:

_is_mamba_installed = False
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# try:
#     from mamba_ssm.models.config_mamba import MambaConfig
#     from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

#     _is_mamba_installed = True
# except ImportError:
#     _is_mamba_installed = False

and with that change the error above shows up immediately at import time. My environment is torch 2.1 with Python 3.10.
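
The undefined symbol is an ATen operator from libtorch, which usually means the compiled selective_scan_cuda extension was built against a different torch than the one installed. A minimal way to confirm in isolation (selective_scan_cuda is the compiled module named in the error; the rest is an assumption about how it was packaged):

# Load torch first, then the compiled extension shipped with mamba-ssm; if the
# torch ABI differs from the one it was built against, the same undefined-symbol
# ImportError as above is raised right here.
import torch
print("torch:", torch.__version__, "CUDA:", torch.version.cuda)

import selective_scan_cuda
print("selective_scan_cuda loaded from:", selective_scan_cuda.__file__)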