huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

`mistralai/Mixtral-8x22B-Instruct-v0.1`: Successful warmup, crashes on inference #2083

Closed alexanderdicke-webcom closed 1 month ago

alexanderdicke-webcom commented 3 months ago

System Info

- TGI Version: v2.0.4
- Model: mistralai/Mixtral-8x22B-Instruct-v0.1
- Hardware: 4x Nvidia H100 70GB HBM3
- Deployment specificities: OpenShift

Reproduction

Running TGI with

The warmup is successful:

2024-06-17T14:22:01.460404Z  INFO text_generation_router: router/src/main.rs:354: Setting max batch total tokens to 42768

The model handles relatively small requests successfully. When larger requests are sent (close to, but still below, the maximum input length), the prefill operation crashes:

2024-06-17T14:22:46.490042Z ERROR text_generation_launcher: Method Prefill encountered an error.
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 90, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 258, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.10/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
    return await self.intercept(
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/interceptor.py", line 21, in intercept
    return await response
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
    raise error
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 144, in Prefill
    generations, next_batch, timings = self.model.generate_token(batch)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 960, in generate_token
    raise e
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 957, in generate_token
    out, speculative_logits = self.forward(batch)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_mistral.py", line 523, in forward
    logits, speculative_logits = self.model.forward(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 647, in forward
    hidden_states = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 589, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 529, in forward
    moe_output = self.moe(normed_attn_res_output)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py", line 367, in forward
    out = fused_moe(
  File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 430, in fused_moe
    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 758.00 MiB. GPU 

I experimented with setting MAX_BATCH_SIZE=1 to make sure that the number of tokens in the batch stays below the max batch total tokens calculated by TGI. However, the error still occurs.
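
For a sense of scale, the allocation that fails in the traceback is `torch.empty((M, topk_ids.shape[1], w2.shape[1]))`, where `M` is the number of tokens in the prefill. The sketch below estimates its size. The hidden size (6144), top-k (2), and 2-byte dtype are assumptions based on the public Mixtral-8x22B configuration and a bf16/fp16 deployment, not values taken from this report, and both helper functions are hypothetical.

```python
MIB = 1024 * 1024


def intermediate_cache3_bytes(num_tokens, top_k=2, hidden_size=6144, dtype_bytes=2):
    """Size in bytes of a (num_tokens, top_k, hidden_size) tensor.

    Defaults are assumptions (Mixtral-8x22B config, 2-byte dtype),
    not values confirmed by the issue.
    """
    return num_tokens * top_k * hidden_size * dtype_bytes


def tokens_for_allocation(alloc_mib, top_k=2, hidden_size=6144, dtype_bytes=2):
    """Largest token count whose buffer fits in alloc_mib MiB."""
    per_token = top_k * hidden_size * dtype_bytes
    return (alloc_mib * MIB) // per_token


if __name__ == "__main__":
    # Token count implied by the 758 MiB request in the traceback.
    print(tokens_for_allocation(758))
    # Buffer size in MiB for a 32768-token prefill.
    print(intermediate_cache3_bytes(32768) / MIB)
```

Under these assumptions, a 758 MiB request corresponds to a prefill of roughly 32k tokens, which is consistent with a single request close to the maximum input length. This buffer is only one of the temporary tensors allocated during the MoE forward pass, so the total transient memory needed for the prefill is larger.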

Expected behavior

No OOM should occur, since the requests stay within the max batch total tokens that TGI calculated during warmup.

LysandreJik commented 3 months ago

Thanks for the report @alexanderdicke-webcom!

After discussing it briefly with Olivier it seems linked to the torch allocator acting up. You're right that after warming up, it should be able to handle the sequences you pass to the model, so we'll need to take a look.
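
One way to check whether the caching allocator is involved is to compare the memory PyTorch has reserved from the driver with the memory that live tensors actually occupy. The sketch below uses `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved`, which are existing PyTorch calls; `allocator_slack` and `report` are hypothetical helpers, and where to call them inside the TGI server is not something this thread specifies.

```python
def allocator_slack(allocated_bytes, reserved_bytes):
    """Bytes held by the caching allocator that do not back live tensors."""
    if reserved_bytes < allocated_bytes:
        raise ValueError("reserved memory cannot be smaller than allocated memory")
    return reserved_bytes - allocated_bytes


def report(device=0):
    """Return allocator statistics for one GPU, or None without CUDA/PyTorch."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    return {
        "allocated_bytes": allocated,
        "reserved_bytes": reserved,
        "slack_bytes": allocator_slack(allocated, reserved),
    }


if __name__ == "__main__":
    print(report())
```

If the slack is large at the moment a 758 MiB request fails, the reserved memory may be split into blocks too small to serve the request. PyTorch documents allocator settings under `PYTORCH_CUDA_ALLOC_CONF` (for example `expandable_segments:True`) that target this situation; whether they help in this case has not been verified here.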

cc @OlivierDehaene

paulcx commented 3 months ago

I ran into a similar OOM issue (34B Llama on 2x A6000): it occurs with sha-74b0231 but not with sha-eade737.

github-actions[bot] commented 2 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.