mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

[BUG] Device error when running on a CUDA device other than cuda:0 #215

Open cornzz opened 3 months ago

cornzz commented 3 months ago

Python -VV

Python 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]

Pip Freeze

conda env export ```shell name: test channels: - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - asttokens=2.0.5=pyhd3eb1b0_0 - bzip2=1.0.8=h5eee18b_6 - ca-certificates=2024.7.2=h06a4308_0 - comm=0.2.1=py311h06a4308_0 - debugpy=1.6.7=py311h6a678d5_0 - decorator=5.1.1=pyhd3eb1b0_0 - executing=0.8.3=pyhd3eb1b0_0 - ipykernel=6.28.0=py311h06a4308_0 - ipython=8.25.0=py311h06a4308_0 - jedi=0.19.1=py311h06a4308_0 - jupyter_client=8.6.0=py311h06a4308_0 - jupyter_core=5.7.2=py311h06a4308_0 - ld_impl_linux-64=2.38=h1181459_1 - libffi=3.4.4=h6a678d5_1 - libgcc-ng=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libsodium=1.0.18=h7b6447c_0 - libstdcxx-ng=11.2.0=h1234567_1 - libuuid=1.41.5=h5eee18b_0 - matplotlib-inline=0.1.6=py311h06a4308_0 - ncurses=6.4=h6a678d5_0 - nest-asyncio=1.6.0=py311h06a4308_0 - openssl=3.0.14=h5eee18b_0 - packaging=24.1=py311h06a4308_0 - parso=0.8.3=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pip=24.2=py311h06a4308_0 - platformdirs=3.10.0=py311h06a4308_0 - prompt-toolkit=3.0.43=py311h06a4308_0 - prompt_toolkit=3.0.43=hd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pure_eval=0.2.2=pyhd3eb1b0_0 - pygments=2.15.1=py311h06a4308_1 - python=3.11.9=h955ad1f_0 - python-dateutil=2.9.0post0=py311h06a4308_2 - pyzmq=25.1.2=py311h6a678d5_0 - readline=8.2=h5eee18b_0 - setuptools=72.1.0=py311h06a4308_0 - six=1.16.0=pyhd3eb1b0_1 - sqlite=3.45.3=h5eee18b_0 - stack_data=0.2.0=pyhd3eb1b0_0 - tk=8.6.14=h39e8969_0 - tornado=6.4.1=py311h5eee18b_0 - traitlets=5.14.3=py311h06a4308_0 - typing_extensions=4.11.0=py311h06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.43.0=py311h06a4308_0 - xz=5.4.6=h5eee18b_1 - zeromq=4.3.5=h6a678d5_0 - zlib=1.2.13=h5eee18b_1 - pip: - accelerate==0.33.0 - aiohappyeyeballs==2.4.0 - aiohttp==3.10.5 - aiosignal==1.3.1 - annotated-types==0.7.0 - anyio==4.4.0 - attrs==24.2.0 - certifi==2024.7.4 - charset-normalizer==3.3.2 - click==8.1.7 - datasets==2.21.0 - dill==0.3.8 - distro==1.9.0 - docstring-parser==0.16 - evaluate==0.4.2 - filelock==3.15.4 - fire==0.6.0 - frozenlist==1.4.1 - fsspec==2024.6.1 - fuzzywuzzy==0.18.0 - h11==0.14.0 - httpcore==1.0.5 - httpx==0.27.1 - huggingface-hub==0.24.5 - idna==3.7 - jieba==0.42.1 - jinja2==3.1.4 - jiter==0.5.0 - joblib==1.4.2 - jsonschema==4.23.0 - jsonschema-specifications==2023.12.1 - llmlingua==0.2.2 - markupsafe==2.1.5 - mistral-common==1.3.4 - mistral-inference==1.3.1 - mpmath==1.3.0 - multidict==6.0.5 - multiprocess==0.70.16 - networkx==3.3 - nltk==3.8.1 - numpy==1.26.4 - nvidia-cublas-cu12==12.1.3.1 - nvidia-cuda-cupti-cu12==12.1.105 - nvidia-cuda-nvrtc-cu12==12.1.105 - nvidia-cuda-runtime-cu12==12.1.105 - nvidia-cudnn-cu12==9.1.0.70 - nvidia-cufft-cu12==11.0.2.54 - nvidia-curand-cu12==10.3.2.106 - nvidia-cusolver-cu12==11.4.5.107 - nvidia-cusparse-cu12==12.1.0.106 - nvidia-nccl-cu12==2.20.5 - nvidia-nvjitlink-cu12==12.6.20 - nvidia-nvtx-cu12==12.1.105 - openai==1.42.0 - pandas==2.2.2 - psutil==6.0.0 - pyarrow==17.0.0 - pydantic==2.8.2 - pydantic-core==2.20.1 - pytz==2024.1 - pyyaml==6.0.2 - referencing==0.35.1 - regex==2024.7.24 - requests==2.32.3 - rouge==1.0.1 - rpds-py==0.20.0 - safetensors==0.4.4 - sentencepiece==0.2.0 - simple-parsing==0.1.5 - sniffio==1.3.1 - sympy==1.13.2 - termcolor==2.4.0 - tiktoken==0.7.0 - tokenizers==0.19.1 - torch==2.4.0 - tqdm==4.66.5 - transformers==4.44.0 - triton==3.0.0 - typing-extensions==4.12.2 - tzdata==2024.1 - urllib3==2.2.2 - xformers==0.0.27.post2 - xxhash==3.5.0 - yarl==1.9.4 prefix: /home/test ```

Reproduction Steps

```python
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Load the model onto a GPU other than the default cuda:0
model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:7")
tokenizer = MistralTokenizer.from_file("./models/mistral-7B-v0.1/tokenizer.model").instruct_tokenizer.tokenizer

prompt = "What is the capital of germany? Answer:"
tokens = tokenizer.encode(prompt, bos=True, eos=False)
out_tokens, logprobs = generate([tokens], model, max_tokens=50, temperature=0)
result = tokenizer.decode(out_tokens[0])
```
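
A possible workaround sketch until this is fixed (assuming the attention bias ends up on the default CUDA device, cuda:0): make the target GPU the current CUDA device before loading the model, either via `CUDA_VISIBLE_DEVICES` or `torch.cuda.set_device`. Untested here.

```python
import torch
from mistral_inference.transformer import Transformer

# Option 1: hide all other GPUs from the process and keep device="cuda:0":
#   CUDA_VISIBLE_DEVICES=7 python repro.py
# Option 2 (sketch, untested): make cuda:7 the current CUDA device so that
# tensors created with a bare "cuda" device land on the same GPU as the weights.
torch.cuda.set_device(7)

model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:7")
# ... rest of the reproduction as above ...
```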

Expected Behavior

I am getting the following error when trying to run the above code:

```
ValueError: Attention bias and Query/Key/Value should be on the same device
  query.device: cuda:7
  attn_bias   : cuda:0
```

This seems related to https://github.com/facebookresearch/xformers/issues/1064, but I haven't been able to figure out why it happens yet...
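
For what it's worth, the check itself lives in xformers' input validation, not in a kernel. Here is a minimal sketch of my own (assuming a machine with at least two visible GPUs) that hits the same `ValueError` by passing a plain tensor bias on a different device than the query:

```python
import torch
from xformers.ops import memory_efficient_attention

# Query/key/value on one GPU, bias deliberately on another (assumes >= 2 GPUs).
q = torch.randn(1, 8, 4, 64, device="cuda:1", dtype=torch.float16)
k = torch.randn(1, 8, 4, 64, device="cuda:1", dtype=torch.float16)
v = torch.randn(1, 8, 4, 64, device="cuda:1", dtype=torch.float16)
bias = torch.randn(1, 4, 8, 8, device="cuda:0", dtype=torch.float16)

# Raises: ValueError: Attention bias and Query/Key/Value should be on the same device
memory_efficient_attention(q, k, v, attn_bias=bias)
```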

Additional Context

Stack trace

```
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/generate.py", line 82, in generate
    prelogits = model.forward(
                ^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 276, in forward
    h = self.forward_partial(input_ids, seqlens, cache=cache)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 258, in forward_partial
    h = layer(h, freqs_cis, cache_view)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 156, in forward
    r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 100, in forward
    output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 276, in memory_efficient_attention
    return _memory_efficient_attention(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 395, in _memory_efficient_attention
    return _memory_efficient_attention_forward(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 411, in _memory_efficient_attention_forward
    inp.validate_inputs()
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/common.py", line 145, in validate_inputs
    raise ValueError(
ValueError: Attention bias and Query/Key/Value should be on the same device
  query.device: cuda:7
  attn_bias   : cuda:0
```

Suggested Solutions

No response

cornzz commented 3 months ago

It seems I found the issue, PR: https://github.com/mistralai/mistral-inference/pull/216
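
For anyone who needs an interim patch, the general shape of such a fix is to make sure the attention bias lives on the query's device before calling xformers. The helper below is a hypothetical sketch only (it handles a plain tensor bias, whereas mistral-inference passes a block-diagonal mask object via `cache.mask`) and is not necessarily what PR #216 does:

```python
import torch
from xformers.ops import memory_efficient_attention

def attention_with_aligned_bias(q, k, v, attn_bias):
    # Hypothetical helper: move a tensor attention bias onto the query's device
    # before calling xformers, avoiding the device-mismatch ValueError.
    if isinstance(attn_bias, torch.Tensor) and attn_bias.device != q.device:
        attn_bias = attn_bias.to(q.device)
    return memory_efficient_attention(q, k, v, attn_bias)
```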