predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Can't run Mistral quantized on T4 #417

Open emillykkejensen opened 4 months ago

emillykkejensen commented 4 months ago

System Info

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000001:00:00.0 Off |                  Off |
| N/A   28C    P0             24W /   70W |       0MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
docker run --runtime nvidia --gpus all --ipc=host -p 8080:80 \
    -v $PWD/data:/data \
    ghcr.io/predibase/lorax:latest \
    --model-id mistralai/Mistral-7B-v0.1 \
    --quantize bitsandbytes-nf4

Reproduction

I'm simply trying to run mistralai/Mistral-7B-v0.1 with 4-bit quantization (bitsandbytes-nf4) on my T4. However, it fails with 'Mistral model requires flash attn v2'.

2024-04-16T14:50:32.809986Z  INFO download: lorax_launcher: Successfully downloaded weights.
2024-04-16T14:50:32.810173Z  INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-04-16T14:50:38.132469Z  WARN lorax_launcher: flash_attn.py:48 Unable to use Flash Attention V2: GPU with CUDA capability 7 5 is not supported for Flash Attention V2

2024-04-16T14:50:38.267395Z ERROR lorax_launcher: server.py:271 Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 89, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 321, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 267, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 179, in get_model
    from lorax_server.models.flash_mistral import FlashMistral
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_mistral.py", line 10, in <module>
    from lorax_server.models.custom_modeling.flash_mistral_modeling import (
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 49, in <module>
    raise ImportError("Mistral model requires flash attn v2")
ImportError: Mistral model requires flash attn v2

2024-04-16T14:50:39.215837Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

Traceback (most recent call last):

  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 89, in serve
    server.serve(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 321, in serve
    asyncio.run(

  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 267, in serve_inner
    model = get_model(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 179, in get_model
    from lorax_server.models.flash_mistral import FlashMistral

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_mistral.py", line 10, in <module>
    from lorax_server.models.custom_modeling.flash_mistral_modeling import (

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 49, in <module>
    raise ImportError("Mistral model requires flash attn v2")

ImportError: Mistral model requires flash attn v2
 rank=0
2024-04-16T14:50:39.314620Z ERROR lorax_launcher: Shard 0 failed to start
2024-04-16T14:50:39.314636Z  INFO lorax_launcher: Shutting down shards

Expected behavior

The model should load and run.

tgaddair commented 4 months ago

Hey @emillykkejensen, unfortunately our minimum supported architecture at the moment is Ampere due to the flash attention dependency. Please see the system requirements here: https://github.com/predibase/lorax?tab=readme-ov-file#requirements
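For readers who want to check their own GPU before launching: the warning in the log above reports CUDA capability 7.5 for the T4 (Turing), while Ampere starts at 8.0. Below is a minimal sketch of that comparison. The names `supports_flash_attn_v2` and `detect_capability` are hypothetical helpers written for this illustration, not part of LoRAX, and the 8.0 threshold is an assumption taken from the Ampere requirement stated in this thread.

```python
from typing import Optional, Tuple

# Assumption: Flash Attention v2 needs compute capability 8.0 (Ampere) or newer,
# per the "min supported architecture is Ampere" statement in this thread.
MIN_FLASH_ATTN_V2_CAPABILITY: Tuple[int, int] = (8, 0)


def supports_flash_attn_v2(capability: Tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability meets the threshold."""
    return tuple(capability) >= MIN_FLASH_ATTN_V2_CAPABILITY


def detect_capability() -> Optional[Tuple[int, int]]:
    """Return the compute capability of GPU 0, or None if it cannot be queried."""
    try:
        import torch  # optional: only needed to query the local GPU
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)


if __name__ == "__main__":
    cap = detect_capability()
    if cap is None:
        print("No CUDA GPU detected (or PyTorch is not installed).")
    elif supports_flash_attn_v2(cap):
        print(f"Capability {cap[0]}.{cap[1]}: meets the Ampere requirement.")
    else:
        print(f"Capability {cap[0]}.{cap[1]}: below Ampere (8.0).")
```

On a T4, `detect_capability()` should return `(7, 5)`, so the check fails before any model weights are loaded.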

emillykkejensen commented 4 months ago

Fair enough. However, one could argue that part of the point of QLoRA is to serve on smaller (older and cheaper) GPUs that predate Ampere. Is there anything in the works?

tgaddair commented 4 months ago

Yes, we have plans to move our attention computation over to the FlashInfer project, which is working on support for Volta and Turing GPUs. So hopefully that will address the issue.

emillykkejensen commented 4 months ago

Sounds good 😊 I'm sure you're already aware, but on the off chance you're not: I can see that there is a fix in TGI. However, it seems they simply fix it by loading the full model?
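To make the memory concern concrete, here is a rough back-of-the-envelope sketch comparing unquantized fp16 weights with 4-bit weights on a 16 GiB T4. The parameter count (about 7.24 billion for Mistral-7B) is an assumption used for illustration, and the estimate covers weights only: it ignores quantization constants, the KV cache, activations, and the CUDA context.

```python
GIB = 1024 ** 3  # bytes per GiB


def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Approximate memory needed for model weights alone, in GiB."""
    return num_params * bits_per_param / 8 / GIB


# Assumption: approximate parameter count of Mistral-7B.
MISTRAL_7B_PARAMS = 7.24e9
# From the nvidia-smi output in this issue: 16384 MiB on the T4.
T4_MEMORY_GIB = 16.0

fp16_gib = weight_memory_gib(MISTRAL_7B_PARAMS, 16)  # full-precision (fp16) load
nf4_gib = weight_memory_gib(MISTRAL_7B_PARAMS, 4)    # 4-bit (e.g. NF4) load

print(f"fp16 weights: {fp16_gib:.1f} GiB of {T4_MEMORY_GIB:.0f} GiB")
print(f"4-bit weights: {nf4_gib:.1f} GiB of {T4_MEMORY_GIB:.0f} GiB")
```

Under these assumptions the fp16 weights take about 13.5 GiB, leaving little room on a T4 for the KV cache and adapters, while the 4-bit weights take about 3.4 GiB.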

nethi commented 1 month ago

Is it fair to assume that this should now work, given that this PR https://github.com/predibase/lorax/issues/440 is merged? With the latest versions, I seem to get past the FA 2 errors but run into a different issue: https://github.com/predibase/lorax/issues/535