huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

Unable to load quantized commandrplus-medusa on H100 #1991

Closed sdadas closed 1 month ago

sdadas commented 2 months ago

System Info

Running the official TGI 2.0.4 Docker image. I am trying to load a quantized version of the text-generation-inference/commandrplus-medusa model on NVIDIA H100 GPUs. From what I have been able to learn, neither EETQ nor bitsandbytes quantization supports Hopper GPUs yet, so the only option is FP8, which should be supported.
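
As a quick sanity check (my own sketch, not something TGI runs itself), FP8 availability can be probed from the device's compute capability; the cuBLASLt FP8 GEMMs behind torch._scaled_mm need capability 8.9 (Ada) or 9.0 (Hopper), so an H100 should qualify:

import torch

# Sketch only: torch._scaled_mm's FP8 kernels require compute
# capability >= (8, 9); an H100 reports (9, 0).
major, minor = torch.cuda.get_device_capability()
print("FP8 matmul supported:", (major, minor) >= (8, 9))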

Reproduction

When running the image with the following command:

docker run --gpus '"device=0,1"' --shm-size 1g -p 8080:80 -v /mnt/data:/data -e HUGGING_FACE_HUB_TOKEN="my_token" -d ghcr.io/huggingface/text-generation-inference:2.0.4 --model-id text-generation-inference/commandrplus-medusa --speculate 3 --num-shard 2 --quantize fp8 

the model crashes during the warmup phase with the following exception:

2024-06-01T14:23:31.232062Z ERROR text_generation_launcher: Method Warmup encountered an error.
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 90, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 257, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.10/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
    return await self.intercept(
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/interceptor.py", line 21, in intercept
    return await response
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
    raise error
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 116, in Warmup
    max_supported_total_tokens = self.model.warmup(batch)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 879, in warmup
    self.cuda_graph_warmup(bs, max_s, max_bt)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 745, in cuda_graph_warmup
    self.model.forward(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_cohere_modeling.py", line 525, in forward
    logits, speculative_logits = self.lm_head(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/speculative.py", line 48, in forward
    return self.speculator(input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/medusa.py", line 180, in forward
    logits = self.lm_head(stacked_x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/tensor_parallel.py", line 75, in forward
    output = super().forward(input)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/tensor_parallel.py", line 13, in forward
    return self.linear.forward(x)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/fp8.py", line 35, in forward
    output, _ = torch._scaled_mm(
RuntimeError: mat1 must be a matrix

The issue only applies to the FP8-quantized model; the FP16 version loads correctly.
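
My reading of the traceback (an assumption, not a confirmed root cause): torch._scaled_mm only accepts 2D operands, while medusa.py stacks the per-head hidden states into a 3D stacked_x tensor before calling lm_head, so the FP8 linear layer fails with "mat1 must be a matrix". Below is a minimal sketch of the failure and a possible workaround, assuming a PyTorch build where _scaled_mm returns an (output, amax) tuple as it does in the traceback; shapes are illustrative, not the model's:

import torch

hidden, vocab = 32, 64
x = torch.randn(3, 16, hidden, device="cuda")   # 3D, like Medusa's stacked_x
w = torch.randn(vocab, hidden, device="cuda")
qx = x.to(torch.float8_e4m3fn)
qw = w.to(torch.float8_e4m3fn)
scale = torch.tensor(1.0, device="cuda")

# This reproduces the crash: _scaled_mm requires a 2D mat1.
# out, _ = torch._scaled_mm(qx, qw.t(), scale_a=scale, scale_b=scale,
#                           out_dtype=torch.float16)

# Possible workaround: flatten the leading dimensions to 2D, then restore them.
out, _ = torch._scaled_mm(
    qx.reshape(-1, hidden), qw.t(),
    scale_a=scale, scale_b=scale, out_dtype=torch.float16,
)
out = out.reshape(3, 16, vocab)

Flattening does not change the result, since the GEMM is applied independently along the flattened leading dimensions. This would also explain why the FP16 path works: a plain torch.nn.functional.linear accepts batched N-dimensional inputs.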

Expected behavior

I expect the model to load correctly.

github-actions[bot] commented 1 month ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.