huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

[Volta] [No flash attention] Llama 3.1 8B Instruct failed to start - "< not supported between instances of 'NoneType' and 'int'" #2440

Open ladi-pomsar opened 3 months ago

ladi-pomsar commented 3 months ago

System Info

Hi everyone, when trying to update from Llama 3 8B Instruct to Llama 3.1 8B Instruct, I noticed a crash:

Args {
    model_id: "meta-llama/Meta-Llama-3.1-8B-Instruct",
    revision: None,
    validation_workers: 2,
    sharded: Some(
        false,
    ),
    num_shard: None,
    quantize: None,
    speculate: None,
    dtype: None,
    trust_remote_code: false,
    max_concurrent_requests: 128,
    max_best_of: 2,
    max_stop_sequences: 4,
    max_top_n_tokens: 5,
    max_input_tokens: Some(
        1500,
    ),
    max_input_length: None,
    max_total_tokens: None,
    waiting_served_ratio: 0.3,
    max_batch_prefill_tokens: None,
    max_batch_total_tokens: None,
    max_waiting_tokens: 20,
    max_batch_size: None,
    cuda_graphs: None,
    hostname: "llm2.internal",
    port: 80,
    shard_uds_path: "/tmp/text-generation-server",
    master_addr: "localhost",
    master_port: 29500,
    huggingface_hub_cache: None,
    weights_cache_override: None,
    disable_custom_kernels: false,
    cuda_memory_fraction: 1.0,
    rope_scaling: None,
    rope_factor: None,
    json_output: false,
    otlp_endpoint: None,
    otlp_service_name: "text-generation-inference.router",
    cors_allow_origin: [],
    api_key: None,
    watermark_gamma: None,
    watermark_delta: None,
    ngrok: false,
    ngrok_authtoken: None,
    ngrok_edge: None,
    tokenizer_config_path: None,
    disable_grammar_support: false,
    env: false,
    max_client_batch_size: 4,
    lora_adapters: None,
    usage_stats: On,
}
2024-08-21T12:39:33.092070Z  INFO hf_hub: Token file not found "/data/token"    
2024-08-21T12:39:33.092172Z  INFO text_generation_launcher: Default `max_total_tokens` to 4096
2024-08-21T12:39:33.092177Z  INFO text_generation_launcher: Default `max_batch_prefill_tokens` to 1550
2024-08-21T12:39:33.092179Z  INFO text_generation_launcher: Using default cuda graphs [1, 2, 4, 8, 16, 32]
2024-08-21T12:39:33.092301Z  INFO download: text_generation_launcher: Starting check and download process for meta-llama/Meta-Llama-3.1-8B-Instruct
2024-08-21T12:39:36.657328Z  INFO text_generation_launcher: Files are already present on the host. Skipping download.
2024-08-21T12:39:37.298067Z  INFO download: text_generation_launcher: Successfully downloaded weights for meta-llama/Meta-Llama-3.1-8B-Instruct
2024-08-21T12:39:37.298432Z  INFO shard-manager: text_generation_launcher: Starting shard rank=0
2024-08-21T12:39:40.778943Z  INFO text_generation_launcher: Using Attention = False
2024-08-21T12:39:40.778983Z  INFO text_generation_launcher: Using Attention = paged
2024-08-21T12:39:40.780265Z  WARN text_generation_launcher: Could not import Flash Attention enabled models: `USE_FLASH_ATTENTION` is false.
2024-08-21T12:39:47.313495Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
2024-08-21T12:39:49.797482Z  INFO text_generation_launcher: Server started at unix:///tmp/text-generation-server-0
2024-08-21T12:39:49.816333Z  INFO shard-manager: text_generation_launcher: Shard ready in 12.516327605s rank=0
2024-08-21T12:39:49.912418Z  INFO text_generation_launcher: Starting Webserver
2024-08-21T12:39:49.959548Z  INFO text_generation_router_v3: backends/v3/src/lib.rs:90: Warming up model
2024-08-21T12:39:49.969069Z ERROR text_generation_launcher: Method Warmup encountered an error.
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 109, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 274, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.10/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
    return await self.intercept(
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/interceptor.py", line 21, in intercept
    return await response
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 120, in _unary_interceptor
    raise error
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 111, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 120, in Warmup
    batch = self.model.batch_type.from_pb(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/causal_lm.py", line 119, in from_pb
    tokenized_inputs = tokenizer(
  File "/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 3073, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 3160, in _call_one
    return self.batch_encode_plus(
  File "/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 3347, in batch_encode_plus
    padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  File "/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2976, in _get_padding_truncation_strategies
    if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0):
TypeError: '<' not supported between instances of 'NoneType' and 'int'
2024-08-21T12:39:49.969363Z ERROR warmup{max_input_length=1500 max_prefill_tokens=1550 max_total_tokens=4096 max_batch_size=None}:warmup: text_generation_router_v3::client: backends/v3/src/client/mod.rs:54: Server error: '<' not supported between instances of 'NoneType' and 'int'
Error: Backend(Warmup(Generation("'<' not supported between instances of 'NoneType' and 'int'")))
2024-08-21T12:39:50.041253Z ERROR text_generation_launcher: Webserver Crashed
2024-08-21T12:39:50.041274Z  INFO text_generation_launcher: Shutting down shards
2024-08-21T12:39:50.116691Z  INFO shard-manager: text_generation_launcher: Terminating shard rank=0
2024-08-21T12:39:50.116733Z  INFO shard-manager: text_generation_launcher: Waiting for shard to gracefully shutdown rank=0
2024-08-21T12:39:51.318250Z  INFO shard-manager: text_generation_launcher: shard terminated rank=0

Deployment mode: Docker Compose
Container settings:

  llm2:
    container_name: llm2
    hostname: llm2
    profiles:
      - common
    image: ghcr.io/huggingface/text-generation-inference:latest
    command: --model-id meta-llama/Meta-Llama-3.1-8B-Instruct --sharded false --max-input-tokens 1500
    volumes:
      - /home/llm_data:/data
    ports:
      - "3000:80"
    environment:
      - HF_HUB_ENABLE_HF_TRANSFER="false"
      - USE_FLASH_ATTENTION=False
      - HF_TOKEN=
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ["4","5"]
              capabilities: [gpu]
    networks:
      - container-network

OS: Ubuntu 22.04.4 LTS
Rust version: N/A
Container version: sha256:b49037cef8d0c61ec022d4d7c5baad22357e34bce7970148a457a11f8f8d7e36
Model being used: meta-llama/Meta-Llama-3.1-8B-Instruct
GPUs: 2x Volta V100 (hence Flash Attention is disabled)

Reproduction

  1. Run the container with the configuration described above.

Expected behavior

Llama 3.1 8B Instruct should start and serve requests, just as Llama 3 8B Instruct did.

ladi-pomsar commented 2 months ago

The crash doesn't occur on a Flash Attention-enabled Ada-generation GPU, so the issue seems specific to setups where Flash Attention is unavailable or disabled.

ladi-pomsar commented 1 month ago

For anyone wondering about this: the crash is caused by the fact that no pad_token is present in Llama's tokenizer_config.json. Something as simple as adding "pad_token": "<|eot_id|>" to the end of that JSON file works around it.

For some reason (code branching?) this doesn't affect FA-enabled GPUs / is already handled on that code path, but it does affect those that need to disable FA.
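
For reference, a minimal sketch of that edit (the path below is hypothetical; point it at wherever the model's tokenizer_config.json is cached, e.g. somewhere under the /home/llm_data volume from the compose file above):

import json
from pathlib import Path

# Hypothetical location: adjust to the actual cached snapshot of the model.
cfg_path = Path("/home/llm_data/.../tokenizer_config.json")

cfg = json.loads(cfg_path.read_text())
cfg.setdefault("pad_token", "<|eot_id|>")  # the workaround described above
cfg_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False))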

dvrogozh commented 5 days ago

I see the same issue running meta-llama/Llama-3.1-8B-Instruct and meta-llama/Llama-3.2-3B-Instruct on an Intel GPU with the 2.4.0-intel-xpu container:

$ docker run --rm --privileged --cap-add=sys_nice -e HF_TOKEN=xxx \
    --device=/dev/dri     --ipc=host --shm-size 1g --net host -v /home/dvrogozh/data:/data \
    ghcr.io/huggingface/text-generation-inference:2.4.0-intel-xpu \
    --model-id meta-llama/Llama-3.2-3B-Instruct --cuda-graphs 0 --port 8080
...
  File "/opt/conda/lib/python3.11/site-packages/text_generation_server/server.py", line 132, in Warmup
    batch = self.model.batch_type.from_pb(
  File "/opt/conda/lib/python3.11/site-packages/text_generation_server/models/causal_lm.py", line 120, in from_pb
    tokenized_inputs = tokenizer(
  File "/opt/conda/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 3016, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/opt/conda/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 3104, in _call_one
    return self.batch_encode_plus(
  File "/opt/conda/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 3297, in batch_encode_plus
    padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
  File "/opt/conda/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 2917, in _get_padding_truncation_strategies
    if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0):
TypeError: '<' not supported between instances of 'NoneType' and 'int'
2024-11-20T21:07:58.767307Z ERROR warmup{max_input_length=4095 max_prefill_tokens=4145 max_total_tokens=4096 max_batch_size=None}:warmup: text_generation_router_v3::client: backends/v3/src/client/mod.rs:45: Server error: '<' not supported between instances of 'NoneType' and 'int'
Error: Backend(Warmup(Generation("'<' not supported between instances of 'NoneType' and 'int'")))

I believe these XPU containers should have attention enabled, but it's a different attention implementation than the CUDA one, so a difference in code branches might still explain why XPU runs into this.

On XPU this is also a regression, because the same setup works fine with the 2.3.0-intel-xpu container:

$ docker run --rm --privileged --cap-add=sys_nice -e HF_TOKEN=xxx \
    --device=/dev/dri     --ipc=host --shm-size 1g --net host -v /home/dvrogozh/data:/data \
    ghcr.io/huggingface/text-generation-inference:2.3.0-intel-xpu \
    --model-id meta-llama/Llama-3.2-3B-Instruct --cuda-graphs 0 --port 8080

@sywangyi

dvrogozh commented 5 days ago

Further, I ran again on an Intel GPU, but with stock PyTorch this time. This variant definitely does not have attention enabled (the setup differs from the Docker XPU runs). There is a behavior change introduced by this PR in Transformers:

Without the above commit, the original issue can be reproduced:

$ cd /path/to/transformers && git reset --hard 187439c3fa139b2102a874483e9f8f0cfa8e5557~1 && pip install -e .
$ text-generation-launcher --model-id meta-llama/Llama-3.2-3B-Instruct --cuda-graphs 0 --port 8080
...
2024-11-20T22:40:08.340828Z  INFO text_generation_router_v3: backends/v3/src/lib.rs:125: Warming up model
2024-11-20T22:40:08.343788Z ERROR text_generation_launcher: Method Warmup encountered an error.
...
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/tokenization_utils_base.py", line 2922, in _get_padding_truncation_strategies
    if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0):
TypeError: '<' not supported between instances of 'NoneType' and 'int'

After the commit I pointed to, the behavior changes and TGI fails earlier, during model initialization:

$ cd /path/to/transformers && git reset --hard 187439c3fa139b2102a874483e9f8f0cfa8e5557 && pip install -e .
$ text-generation-launcher --model-id meta-llama/Llama-3.2-3B-Instruct --cuda-graphs 0 --port 8080
...
2024-11-20T22:41:02.635005Z ERROR text_generation_launcher: Error when initializing model
...
> File "/home/dvrogozh/git/huggingface/text-generation-inference/server/text_generation_server/server.py", line 268, in serve_inner
    model = get_model_with_lora_adapters(
  File "/home/dvrogozh/git/huggingface/text-generation-inference/server/text_generation_server/models/__init__.py", line 1336, in get_model_with_lora_adapters
    model = get_model(
  File "/home/dvrogozh/git/huggingface/text-generation-inference/server/text_generation_server/models/__init__.py", line 878, in get_model
    return CausalLM.fallback(
  File "/home/dvrogozh/git/huggingface/text-generation-inference/server/text_generation_server/models/causal_lm.py", line 634, in fallback
    tokenizer.pad_token_id = model.config.eos_token_id
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/tokenization_utils_base.py", line 1076, in __setattr__
    raise ValueError(f"Cannot set a non-string value as the {key}")
ValueError: Cannot set a non-string value as the pad_token

The failure in the second case is here: https://github.com/dvrogozh/transformers/blob/187439c3fa139b2102a874483e9f8f0cfa8e5557/src/transformers/tokenization_utils_base.py#L1076

Printing the values involved, I see that __setattr__ was called as:

__setattr__('pad_token', ['<|end_of_text|>', '<|eom_id|>', '<|eot_id|>'])

I.e. we have a list of tokens instead of a single string value, and __setattr__ does not seem to account for this case. I wonder whether this also gives us a clue as to why we run into the original issue? Maybe the attention path somehow handles a list of tokens correctly? @zucchini-nlp, as the author of https://github.com/huggingface/transformers/pull/34461, and @ArthurZucker: could you comment on this behavior and suggest further debugging steps?
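
To illustrate the difference, a minimal check against that __setattr__ behavior (assuming a Transformers checkout that includes the commit above):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# A plain string is accepted as pad_token.
tok.pad_token = "<|eot_id|>"
print(f"pad_token set to {tok.pad_token!r} (id={tok.pad_token_id})")

# A list of tokens is rejected by __setattr__ on this revision.
try:
    tok.pad_token = ["<|end_of_text|>", "<|eom_id|>", "<|eot_id|>"]
except ValueError as e:
    print(f"list rejected: {e}")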

sywangyi commented 5 days ago

https://github.com/huggingface/text-generation-inference/pull/2702 (already merged) is meant to fix the following issue:
2024-11-21T00:40:07.383973Z WARN text_generation_launcher: Could not import Flash Attention enabled models: No module named 'triton'
which caused XPU to fall back to causal_lm.py instead of flash_causal_lm.py.

You should use the image ghcr.io/huggingface/text-generation-inference:latest-intel-xpu to avoid this issue. meta-llama/Llama-3.2-3B-Instruct should work in the latest TGI XPU image.

dvrogozh commented 5 days ago

@sywangyi: thank you for pointing this out, I missed that warning. Indeed, ghcr.io/huggingface/text-generation-inference:latest-intel-xpu works for me. This also correlates with @ladi-pomsar's observation that the issue is specific to cases where attention is not enabled. Do you have any idea what in the attention path makes things work?

dvrogozh commented 4 days ago

Basically, here is a simplified script to reproduce the issue. That's what TGI is doing around https://github.com/huggingface/text-generation-inference/blob/07bed530f7eaf2419ed0e755e0f24d7afd814a46/server/text_generation_server/models/causal_lm.py#L634

Script:

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-3B-Instruct')
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3.2-3B-Instruct')

print(f">>> tokenizer.pad_token_id={tokenizer.pad_token_id}")
print(f">>> model.config.pad_token_id={model.config.pad_token_id}")
print(f">>> model.config.eos_token_id={model.config.eos_token_id}")

tokenizer.pad_token_id = model.config.eos_token_id

Output:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.03it/s]
>>> tokenizer.pad_token_id=None
>>> model.config.pad_token_id=None
>>> model.config.eos_token_id=[128001, 128008, 128009]
Traceback (most recent call last):
  File "/home/dvrogozh/tmp/e.py", line 11, in <module>
    tokenizer.pad_token_id = model.config.eos_token_id
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/tokenization_utils_base.py", line 1076, in __setattr__
    raise ValueError(f"Cannot set a non-string value as the {key}")
ValueError: Cannot set a non-string value as the pad_token

Unfortunately, my knowledge of Transformers is not enough to say what's wrong and where.
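
That said, the error in the script above can be avoided by normalizing the list before the assignment. This is only a sketch on my side, not the actual TGI or Transformers fix; it assumes the pad_token_id setter accepts a single integer id, which the traceback suggests:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

eos = model.config.eos_token_id
# Llama 3.x checkpoints ship a list of EOS ids; the pad_token setter only
# accepts a single token, so fall back to the first id in that case.
if isinstance(eos, (list, tuple)):
    eos = eos[0]
tokenizer.pad_token_id = eos

print(f">>> tokenizer.pad_token={tokenizer.pad_token!r}")
print(f">>> tokenizer.pad_token_id={tokenizer.pad_token_id}")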

dvrogozh commented 4 days ago

Hm. I like this place in the attention path (which works): https://github.com/huggingface/text-generation-inference/blob/ab7ccf5bc3c84e07d0faf0d950421fcdc29743b5/server/text_generation_server/models/flash_causal_lm.py#L1261-L1263

This was introduced by the following PR:

@Narsil: do you recall the details of the tokenizer._eos_token_ids hack? Can you suggest how to resolve the issue we observe on the non-attention path in TGI, which can be reproduced with the simple Transformers script in my comment above?

dvrogozh commented 4 days ago

I have filed an issue/question on the Transformers side:

zucchini-nlp commented 4 days ago

Basically, here is a simplified script to reproduce the issue. That's what TGI is doing around

Cool, thanks for the reproducer. I will check it out and comment under the Transformers issue.

ladi-pomsar commented 4 days ago

@sywangyi: thank you for pointing this out, I missed that warning. Indeed, ghcr.io/huggingface/text-generation-inference:latest-intel-xpu works for me. This also correlates with @ladi-pomsar's observation that the issue is specific to cases where attention is not enabled. Do you have any idea what in the attention path makes things work?

I didn't post a follow-up, but if you disable Flash Attention on newer NVIDIA generations through the TGI environment variable USE_FLASH_ATTENTION=False, you can reproduce it there as well.

dvrogozh commented 3 days ago

@zucchini-nlp: thank you for the feedback. I've posted #2774 with the proposed fix.

if you disable Flash Attention on newer NVIDIA generations through the TGI environment variable USE_FLASH_ATTENTION=False, you can reproduce it there as well.

@zucchini-nlp: indeed. After #2774 this case starts working on NVIDIA as well.