huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

CUDA OOM on first request with bitsandbytes quantization #1166

Closed RDearnaley closed 10 months ago

RDearnaley commented 1 year ago

System Info

TGI version: 1.1.0
LLM: Mistral 7B Instruct v0.1
Virtual hardware: Kubernetes on Azure

resources:
  limits:
    cpu: 7.
    memory: 60G
    ephemeral-storage: 140G
    nvidia.com/mig-2g.20gb: 1

nvidia.com/mig-2g.20gb = A100 80GB GPU sliced down to a MIG-2g slice (2/7 of the compute, 2/8 = ~20GB GPU RAM)
Operating system: Ubuntu Linux
Kubernetes version: 1.26.6
Node size: Standard_NC24ads_A100_v4
Node image version: AKSUbuntu-2204gen2containerd-202310.04.0
Kernel version: 5.15.0-1049-azure
Container runtime version: containerd://1.7.5-1
Manually created with command: az aks nodepool add --mode System --node-osdisk-size 512 --node-osdisk-type Managed --name --resource-group --cluster-name **** --node-count 1 --node-vm-size Standard_NC24ads_A100_v4 --gpu-instance-profile MIG2g
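
As a sanity check on how much memory the MIG-2g slice actually exposes to the container (the ~19.5 GiB that later shows up in the OOM message), a generic torch snippet along these lines can be run inside the pod; it is illustrative only and not part of the deployment:

    import torch

    # How much GPU memory the MIG slice exposes to this process.
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    print(f"device: {torch.cuda.get_device_name(0)}")
    print(f"total:  {total_bytes / 2**30:.2f} GiB")  # ~19.5 GiB on a 2g.20gb slice
    print(f"free:   {free_bytes / 2**30:.2f} GiB")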

Reproduction

Using Kubernetes on Ubuntu:

  1. With the current TGI version (1.1.0), using this image configuration:

    image:
      repository: ghcr.io/huggingface/text-generation-inference
      pullPolicy: Always
      tag: "1.1.0"
  2. Environment variables set:

            - name: "MODEL_ID"
              value: "mistralai/Mistral-7B-Instruct-v0.1"
            - name: "QUANTIZE"
              value: "bitsandbytes"
            - name: "SHARDED"
              value: "false"
            - name: "PORT"
              value: "80"
            - name: "MAX_TOTAL_TOKENS"
              value: "5632"
            - name: "MAX_INPUT_LENGTH"
              value: "3584"

    These settings for Mistral 7B make use of its sliding window attention, but the problem also occurs with MAX_TOTAL_TOKENS=4096 and MAX_INPUT_LENGTH=2048, which do not use the sliding window. (I'm using bitsandbytes quantization because setting the quantization to eetq produced errors saying it needed to be installed.)

  3. Starting the model up produces normal-looking logs:

    {"timestamp":"2023-10-17T18:33:19.394561Z","level":"INFO","fields":{"message":"Args { model_id: \"mistralai/Mistral-7B-Instruct-v0.1\", revision: None, validation_workers: 2, sharded: Some(false), num_shard: None, quantize: Some(Bitsandbytes), dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_top_n_tokens: 5, max_input_length: 3584, max_total_tokens: 5632, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, hostname: \"ydc-unc-sm-llm-server-65499848fb-52ts2\", port: 80, shard_uds_path: \"/tmp/text-generation-server\", master_addr: \"localhost\", master_port: 29500, huggingface_hub_cache: Some(\"/data\"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, rope_scaling: None, rope_factor: None, json_output: true, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false }"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:19.394669Z","level":"INFO","fields":{"message":"Starting download process."},"target":"text_generation_launcher","span":{"name":"download"},"spans":[{"name":"download"}]}
    {"timestamp":"2023-10-17T18:33:22.037246Z","level":"WARN","fields":{"message":"No safetensors weights found for model mistralai/Mistral-7B-Instruct-v0.1 at revision None. Downloading PyTorch weights.\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:22.340536Z","level":"INFO","fields":{"message":"Download file: pytorch_model-00001-of-00002.bin\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:43.307713Z","level":"INFO","fields":{"message":"Downloaded /data/models--mistralai--Mistral-7B-Instruct-v0.1/snapshots/7ad5799710574ba1c1d953eba3077af582f3a773/pytorch_model-00001-of-00002.bin in 0:00:20.\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:43.307800Z","level":"INFO","fields":{"message":"Download: [1/2] -- ETA: 0:00:20\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:43.308102Z","level":"INFO","fields":{"message":"Download file: pytorch_model-00002-of-00002.bin\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:59.351971Z","level":"INFO","fields":{"message":"Downloaded /data/models--mistralai--Mistral-7B-Instruct-v0.1/snapshots/7ad5799710574ba1c1d953eba3077af582f3a773/pytorch_model-00002-of-00002.bin in 0:00:16.\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:59.352053Z","level":"INFO","fields":{"message":"Download: [2/2] -- ETA: 0\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:33:59.352130Z","level":"WARN","fields":{"message":"No safetensors weights found for model mistralai/Mistral-7B-Instruct-v0.1 at revision None. Converting PyTorch weights to safetensors.\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:36:23.495078Z","level":"INFO","fields":{"message":"Convert: [1/2] -- Took: 0:02:23.725971\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:36:50.433431Z","level":"INFO","fields":{"message":"Convert: [2/2] -- Took: 0:00:26.938025\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:36:50.764997Z","level":"INFO","fields":{"message":"Successfully downloaded weights."},"target":"text_generation_launcher","span":{"name":"download"},"spans":[{"name":"download"}]}
    {"timestamp":"2023-10-17T18:36:50.765254Z","level":"INFO","fields":{"message":"Starting shard"},"target":"text_generation_launcher","span":{"rank":0,"name":"shard-manager"},"spans":[{"rank":0,"name":"shard-manager"}]}
    {"timestamp":"2023-10-17T18:36:55.371793Z","level":"WARN","fields":{"message":"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:37:00.776676Z","level":"INFO","fields":{"message":"Waiting for shard to be ready..."},"target":"text_generation_launcher","span":{"rank":0,"name":"shard-manager"},"spans":[{"rank":0,"name":"shard-manager"}]}
    {"timestamp":"2023-10-17T18:37:04.330684Z","level":"INFO","fields":{"message":"Server started at unix:///tmp/text-generation-server-0\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:37:04.382726Z","level":"INFO","fields":{"message":"Shard ready in 13.616763293s"},"target":"text_generation_launcher","span":{"rank":0,"name":"shard-manager"},"spans":[{"rank":0,"name":"shard-manager"}]}
    {"timestamp":"2023-10-17T18:37:04.482828Z","level":"INFO","fields":{"message":"Starting Webserver"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T18:37:04.652254Z","level":"WARN","message":"`--revision` is not set","target":"text_generation_router","filename":"router/src/main.rs","line_number":349}
    {"timestamp":"2023-10-17T18:37:04.652283Z","level":"WARN","message":"We strongly advise to set it to a known supported commit.","target":"text_generation_router","filename":"router/src/main.rs","line_number":350}
    {"timestamp":"2023-10-17T18:37:04.720079Z","level":"INFO","message":"Serving revision 7ad5799710574ba1c1d953eba3077af582f3a773 of model mistralai/Mistral-7B-Instruct-v0.1","target":"text_generation_router","filename":"router/src/main.rs","line_number":371}
    {"timestamp":"2023-10-17T18:37:04.724411Z","level":"INFO","message":"Warming up model","target":"text_generation_router","filename":"router/src/main.rs","line_number":213}
    {"timestamp":"2023-10-17T18:37:07.033684Z","level":"INFO","message":"Setting max batch total tokens to 62624","target":"text_generation_router","filename":"router/src/main.rs","line_number":246}
    {"timestamp":"2023-10-17T18:37:07.033711Z","level":"INFO","message":"Connected","target":"text_generation_router","filename":"router/src/main.rs","line_number":247}
    {"timestamp":"2023-10-17T18:37:07.033717Z","level":"WARN","message":"Invalid hostname, defaulting to 0.0.0.0","target":"text_generation_router","filename":"router/src/main.rs","line_number":252}

    These contain no obviously significant issues. Note the line Setting max batch total tokens to 62624 (see the KV-cache sketch after this list for where a figure of that size comes from).

  4. The first query sent to the server reliably produces a CUDA OOM error:

    {"timestamp":"2023-10-17T19:33:40.617102Z","level":"ERROR","fields":{"message":"Method Prefill encountered an error.\nTraceback (most recent call last):\n  File \"/opt/conda/bin/text-generation-server\", line 8, in <module>\n    sys.exit(app())\n  File \"/opt/conda/lib/python3.9/site-packages/typer/main.py\", line 311, in __call__\n    return get_command(self)(*args, **kwargs)\n  File \"/opt/conda/lib/python3.9/site-packages/click/core.py\", line 1157, in __call__\n    return self.main(*args, **kwargs)\n  File \"/opt/conda/lib/python3.9/site-packages/typer/core.py\", line 778, in main\n    return _main(\n  File \"/opt/conda/lib/python3.9/site-packages/typer/core.py\", line 216, in _main\n    rv = self.invoke(ctx)\n  File \"/opt/conda/lib/python3.9/site-packages/click/core.py\", line 1688, in invoke\n    return _process_result(sub_ctx.command.invoke(sub_ctx))\n  File \"/opt/conda/lib/python3.9/site-packages/click/core.py\", line 1434, in invoke\n    return ctx.invoke(self.callback, **ctx.params)\n  File \"/opt/conda/lib/python3.9/site-packages/click/core.py\", line 783, in invoke\n    return __callback(*args, **kwargs)\n  File \"/opt/conda/lib/python3.9/site-packages/typer/main.py\", line 683, in wrapper\n    return callback(**use_params)  # type: ignore\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py\", line 83, in serve\n    server.serve(\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py\", line 207, in serve\n    asyncio.run(\n  File \"/opt/conda/lib/python3.9/asyncio/runners.py\", line 44, in run\n    return loop.run_until_complete(main)\n  File \"/opt/conda/lib/python3.9/asyncio/base_events.py\", line 634, in run_until_complete\n    self.run_forever()\n  File \"/opt/conda/lib/python3.9/asyncio/base_events.py\", line 601, in run_forever\n    self._run_once()\n  File \"/opt/conda/lib/python3.9/asyncio/base_events.py\", line 1905, in _run_once\n    handle._run()\n  File \"/opt/conda/lib/python3.9/asyncio/events.py\", line 80, in _run\n    self._context.run(self._callback, *self._args)\n  File \"/opt/conda/lib/python3.9/site-packages/grpc_interceptor/server.py\", line 159, in invoke_intercept_method\n    return await self.intercept(\n> File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/interceptor.py\", line 21, in intercept\n    return await response\n  File \"/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py\", line 82, in _unary_interceptor\n    raise error\n  File \"/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py\", line 73, in _unary_interceptor\n    return await behavior(request_or_iterator, context)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py\", line 94, in Prefill\n    generations, next_batch = self.model.generate_token(batch)\n  File \"/opt/conda/lib/python3.9/contextlib.py\", line 79, in inner\n    return func(*args, **kwds)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_causal_lm.py\", line 753, in generate_token\n    raise e\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_causal_lm.py\", line 750, in generate_token\n    out = self.forward(batch)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_mistral.py\", line 343, in forward\n    logits = self.model.forward(\n  File 
\"/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py\", line 518, in forward\n    hidden_states = self.model(\n  File \"/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py\", line 463, in forward\n    hidden_states, residual = layer(\n  File \"/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py\", line 406, in forward\n    mlp_output = self.mlp(normed_attn_res_output)\n  File \"/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py\", line 348, in forward\n    gate_up_states = self.gate_up_proj(hidden_states)\n  File \"/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n    return forward_call(*args, **kwargs)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/layers.py\", line 351, in forward\n    return self.linear.forward(x)\n  File \"/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/layers.py\", line 219, in forward\n    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)\n  File \"/opt/conda/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py\", line 563, in matmul\n    return MatMul8bitLt.apply(A, B, out, bias, state)\n  File \"/opt/conda/lib/python3.9/site-packages/torch/autograd/function.py\", line 506, in apply\n    return super().apply(*args, **kwargs)  # type: ignore[misc]\n  File \"/opt/conda/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py\", line 386, in forward\n    state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)\ntorch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 446.00 MiB (GPU 0; 19.50 GiB total capacity; 17.18 GiB already allocated; 42.94 MiB free; 19.50 GiB allowed; 18.20 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n"},"target":"text_generation_launcher"}
    {"timestamp":"2023-10-17T19:33:40.618526Z","level":"ERROR","message":"Server error: CUDA out of memory. Tried to allocate 446.00 MiB (GPU 0; 19.50 GiB total capacity; 17.18 GiB already allocated; 42.94 MiB free; 19.50 GiB allowed; 18.20 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF","target":"text_generation_client","filename":"router/client/src/lib.rs","line_number":33,"span":{"id":0,"size":1,"name":"prefill"},"spans":[{"batch_size":1,"name":"batch"},{"name":"prefill"},{"id":0,"size":1,"name":"prefill"},{"id":0,"size":1,"name":"prefill"}]}
    {"timestamp":"2023-10-17T19:33:40.619254Z","level":"ERROR","message":"Request failed during generation: Server error: CUDA out of memory. Tried to allocate 446.00 MiB (GPU 0; 19.50 GiB total capacity; 17.18 GiB already allocated; 42.94 MiB free; 19.50 GiB allowed; 18.20 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF","target":"text_generation_router::infer","filename":"router/src/infer.rs","line_number":589,"span":{"name":"send_error"},"spans":[{"default_return_full_text":"true","name":"compat_generate"},{"parameters":"GenerateParameters { best_of: None, temperature: Some(0.7), repetition_penalty: Some(1.1), top_k: Some(100), top_p: Some(0.95), typical_p: Some(0.95), do_sample: false, max_new_tokens: 2048, return_full_text: Some(false), stop: [], truncate: None, watermark: false, details: true, decoder_input_details: false, seed: None, top_n_tokens: None }","name":"generate_stream"},{"name":"async_stream"},{"request":"GenerateRequest { inputs: \"[### LARGE PROMPT ELIDED ###]\", parameters: GenerateParameters { best_of: None, temperature: Some(0.7), repetition_penalty: Some(1.1), top_k: Some(100), top_p: Some(0.95), typical_p: Some(0.95), do_sample: false, max_new_tokens: 2048, return_full_text: Some(false), stop: [], truncate: None, watermark: false, details: true, decoder_input_details: false, seed: None, top_n_tokens: None } }","name":"generate_stream"},{"name":"infer"},{"name":"send_error"}]}
  5. Unsuccessful remediations attempted: setting MAX_BATCH_TOTAL_TOKENS to a value less than 62624 does not work, because this is a Flash Attention model and my value gets overridden by the inferred value (which presumably is too high). Setting CUDA_MEMORY_FRACTION does not work either: while it reduces the inferred max batch total tokens, it is also passed to torch.cuda.set_per_process_memory_fraction(…), which makes CUDA report an OOM at a lower amount of memory used, so it lowers the memory ceiling at the same rate as it reduces the memory overallocation (see the sketch of this cancellation after the list). I briefly experimented with setting:

            - name: "PYTORCH_CUDA_ALLOC_CONF"
              value: "max_split_size_mb:<various values>"

    but was unable to find a value that avoided the problem. In any case, memory fragmentation seems unlikely as early as the first query, unless it occurred while downloading the model and converting it to safetensors.

  6. Time-consuming, possibly relevant things I haven't yet tried: eliminating bitsandbytes quantization from the setup; preconverting the Mistral 7B model to safetensors and uploading it to Hugging Face so my server replicas can download safetensors directly instead of converting after download; and an exhaustive search of possible max_split_size_mb values.

  7. Presumably the inferred max batch total tokens calculation is wrong for this new model with these settings (unless there is a fast memory leak), but I don't understand it well enough to find and fix the error. If the MAX_BATCH_TOTAL_TOKENS environment variable were not overwritten by the inferred value for Flash Attention models, or if there were a variable like CUDA_MEMORY_FRACTION that only applied to the inferred max batch total tokens calculation and was not also passed to torch.cuda.set_per_process_memory_fraction(…), then I could manually tweak it to compensate. I could hack the server Python code to implement either of these, but then I would need to build my own Docker image with those hacks rather than relying on the official one.
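
As context for the "Setting max batch total tokens to 62624" line noted in step 3, here is a back-of-the-envelope sketch of the KV-cache arithmetic for Mistral 7B (32 layers, 8 key/value heads, head dimension 128). This is my approximation of the kind of budget the warmup-based estimate produces, not the actual TGI code:

    # Rough KV-cache arithmetic for Mistral-7B-Instruct-v0.1 (approximation,
    # not the actual TGI warmup code). The KV cache is normally kept in fp16
    # even when the weights are quantized with bitsandbytes.
    num_layers = 32
    num_kv_heads = 8
    head_dim = 128
    dtype_bytes = 2  # fp16

    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes  # K and V
    print(bytes_per_token)  # 131072 bytes = 128 KiB of KV cache per token

    max_batch_total_tokens = 62624
    kv_cache_bytes = max_batch_total_tokens * bytes_per_token
    print(f"{kv_cache_bytes / 2**30:.1f} GiB")  # ~7.6 GiB reserved for the KV cache

    # ~7 GiB of int8 weights plus ~7.6 GiB of KV cache already puts the shard
    # around 15 GiB of the 19.5 GiB slice before any per-request activations
    # or bitsandbytes temporaries are counted.

Because the token budget is derived from memory that is free after warmup, the transient allocations the bitsandbytes int8 path makes per request (the state.subB allocation in the traceback above) are not part of it, which is consistent with the first real query pushing the shard over the 19.5 GiB cap.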

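The CUDA_MEMORY_FRACTION cancellation described in step 5 can be sketched with public torch APIs; the variable names and the budget formula below are illustrative assumptions, not TGI's exact code:

    import torch

    fraction = 0.8  # hypothetical CUDA_MEMORY_FRACTION value

    # 1. The fraction becomes a hard allocator cap, so OOM now triggers at
    #    fraction * total capacity instead of the full 19.5 GiB.
    torch.cuda.set_per_process_memory_fraction(fraction, device=0)

    # 2. The same fraction also shrinks the budget used to infer max batch
    #    total tokens, i.e. something of the shape:
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    kv_budget = free_bytes - (1 - fraction) * total_bytes  # smaller KV cache

    # Because the ceiling in (1) falls at the same rate as the budget in (2),
    # lowering the fraction never creates extra headroom for the transient
    # allocations that trigger the OOM.
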
Expected behavior

Server can perform inference without an immediate CUDA OOM on the first query (or indeed any early query).

RDearnaley commented 1 year ago

Confirmed that the issue goes away if I turn off bitsandbytes quantization.
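
To see the request-time overhead of the int8 path outside TGI, a standalone sketch along these lines could be used (the layer dimensions match Mistral's fused gate_up_proj and the input size matches MAX_INPUT_LENGTH; the real overhead depends on how many outlier columns the actual activations trigger, so treat it as an illustration, not a faithful reproduction):

    import torch
    import bitsandbytes as bnb

    # One gate_up_proj-sized layer (4096 -> 2 * 14336) with int8 weights.
    layer = bnb.nn.Linear8bitLt(4096, 2 * 14336, bias=False,
                                has_fp16_weights=False, threshold=6.0)
    layer = layer.cuda()  # quantization to int8 happens on the .cuda() call

    # A prefill-sized batch of hidden states (3584 tokens of dimension 4096).
    x = torch.randn(3584, 4096, dtype=torch.float16, device="cuda")
    x[:, :32] *= 8.0  # force some outlier columns past the 6.0 threshold

    torch.cuda.reset_peak_memory_stats()
    before = torch.cuda.memory_allocated()
    with torch.no_grad():
        y = layer(x)
    torch.cuda.synchronize()
    extra = torch.cuda.max_memory_allocated() - before
    print(f"peak transient memory during the int8 matmul: {extra / 2**20:.0f} MiB")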

github-actions[bot] commented 10 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.