huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

Client validation error when server generates <unk> token #440

Closed edwardzjl closed 1 year ago

edwardzjl commented 1 year ago

System Info

text-generation-inference version: v0.8.2
text-generation version (Python client): 0.6.0
GPU: NVIDIA A100 40G

text-generation-launcher env:

> text-generation-launcher --env
2023-06-12T09:13:08.093676Z  INFO text_generation_launcher: Runtime environment:
Target: x86_64-unknown-linux-gnu
Cargo version: 1.69.0
Commit sha: 5fde8d99919f3c81e1fd414aa11d2148680baea6
Docker label: sha-5fde8d9
nvidia-smi:
Mon Jun 12 09:13:07 2023       
   +---------------------------------------------------------------------------------------+
   | NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
   |-----------------------------------------+----------------------+----------------------+
   | GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
   | Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
   |                                         |                      |               MIG M. |
   |=========================================+======================+======================|
   |   0  NVIDIA A100-PCIE-40GB           Off| 00000000:00:06.0 Off |                    0 |
   | N/A   26C    P0               38W / 250W|  18937MiB / 40960MiB |      0%      Default |
   |                                         |                      |             Disabled |
   +-----------------------------------------+----------------------+----------------------+

   +---------------------------------------------------------------------------------------+
   | Processes:                                                                            |
   |  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
   |        ID   ID                                                             Usage      |
   |=======================================================================================|
   +---------------------------------------------------------------------------------------+
2023-06-12T09:13:08.093834Z  INFO text_generation_launcher: Args { model_id: "bigscience/bloom-560m", revision: None, sharded: None, num_shard: None, quantize: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_input_length: 1000, max_total_tokens: 1512, max_batch_size: None, waiting_served_ratio: 1.2, max_batch_total_tokens: 32000, max_waiting_tokens: 20, port: 8080, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, env: true }

(The last line of text_generation_launcher: Args above seems incorrect; we run text-generation-inference in a Kubernetes pod and pass arguments through containers.args.)

The real args can be found via the /info endpoint (a query sketch follows the JSON below):

{
    "model_id": "/data",
    "model_sha": null,
    "model_dtype": "torch.float16",
    "model_device_type": "cuda",
    "model_pipeline_tag": null,
    "max_concurrent_requests": 128,
    "max_best_of": 2,
    "max_stop_sequences": 4,
    "max_input_length": 1000,
    "max_total_tokens": 2048,
    "waiting_served_ratio": 1.2,
    "max_batch_total_tokens": 32000,
    "max_waiting_tokens": 20,
    "validation_workers": 2,
    "version": "0.8.2",
    "sha": "19c41824cb11ba1a3b60a2a65274d8c074383de3",
    "docker_label": "sha-19c4182"
}
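
A minimal sketch of reading that configuration programmatically (assuming the pod is port-forwarded to localhost:8080):

# Sketch: read the effective server configuration from the /info endpoint.
import requests

info = requests.get("http://localhost:8080/info").json()
print(info["max_input_length"], info["max_total_tokens"])  # 1000 2048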

Reproduction

When the text-generation-inference service generates an <unk> token (which does not have a 'logprob', and I don't know why), the text-generation Python client raises a validation error. The traceback is below, followed by a minimal reproduction sketch.

...
  File "d:\workspace\myproj\venv\Lib\site-packages\text_generation\client.py", line 150, in generate
    return Response(**payload[0])
           ^^^^^^^^^^^^^^^^^^^^^^
  File "pydantic\main.py", line 341, in pydantic.main.BaseModel.__init__
pydantic.error_wrappers.ValidationError: 1000 validation errors for Response
details -> tokens -> 0 -> logprob
  none is not an allowed value (type=type_error.none.not_allowed)
...
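
A minimal client-side reproduction sketch (assuming a server at http://localhost:8080 and that the model emits an <unk> token at this temperature):

# Sketch: this call raises pydantic.ValidationError instead of returning a
# Response when a generated token comes back with a null logprob.
from text_generation import Client

client = Client("http://localhost:8080")
response = client.generate(
    "The sky is blue because",
    max_new_tokens=20,
    temperature=0.0001,  # a very low temperature makes the failure more likely (see below)
)
print(response.generated_text)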

The corresponding code lies in text_generation/types.py:

# FinishReason, InputToken and BestOfSequence are defined earlier in the same module
from typing import List, Optional
from pydantic import BaseModel

# Generated tokens
class Token(BaseModel):
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    logprob: float # <-- this field
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool

# `generate` details
class Details(BaseModel):
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
    # Additional sequences when using the `best_of` parameter
    best_of_sequences: Optional[List[BestOfSequence]]

# `generate` return value
class Response(BaseModel):
    # Generated text
    generated_text: str
    # Generation details
    details: Details

Expected behavior

This error occurs during deserialization of the Response object.

As the response code is 200, I suppose the client should not raise an error?

Maybe we should make 'logprob' optional?
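
A sketch of what that could look like (just a suggestion on the client side, not the maintainers' decision):

# Suggested change in text_generation/types.py: allow a null logprob so
# deserialization no longer fails when the server returns one.
from typing import Optional
from pydantic import BaseModel

class Token(BaseModel):
    id: int
    text: str
    # Optional so a null value no longer raises a ValidationError
    logprob: Optional[float]
    special: bool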

OlivierDehaene commented 1 year ago

The last line of text_generation_launcher: Args seems not correct

What do you mean?

The logprob should always have a value. If it does not, something is going wrong, hence the validation error. I will investigate a bit on my side.

edwardzjl commented 1 year ago

We pass some configuration through command-line args, for example --max-input-length 1000 --max-total-tokens 2048, which is reflected in the $BASE_URL/info endpoint but not in the text-generation-launcher --env output. After some thought I believe this is the correct behavior, as text-generation-launcher --env reflects the environment rather than the running process. Sorry for the misleading description.

edwardzjl commented 1 year ago

It seems that this happens when using a very low temperature.

For example, when using a temperature of 1e-3, the server responds correctly.

> curl -X POST -H "content-type: application/json" -d '{"inputs": "The sky is blue because", "parameters": {"temperature": 0.001, "max_new_tokens":20}}' "http://localhost:8080/generate"
{"generated_text":" the Earth's atmosphere scatters sunlight in all directions. The scattering is caused by the molecules and particles in"}

However, if I change the temperature to 1e-4, the server responds with an empty generation.

> curl -X POST -H "content-type: application/json" -d '{"inputs": "The sky is blue because", "parameters": {"temperature": 0.0001, "max_new_tokens":20}}' "http://localhost:8080/generate"
{"generated_text":""}

OlivierDehaene commented 1 year ago

If you set a temperature of 1e-4, why not simply use greedy decoding?

edwardzjl commented 1 year ago

I'm new to text generation tasks, but I want to lower the "creativity" of the model and stick to stable outputs. To my limited knowledge, I think I should set a low temperature (when using OpenAI I can set the temperature to 0).

But the issue is not about that; it is that text-generation-server generates a null logprob while the text-generation Python client does not allow that.
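
Hitting the REST API directly (bypassing the client's pydantic models) shows the null values; a sketch assuming the same server and parameters as the curl commands above:

# Sketch: request token details over raw HTTP and inspect the logprobs.
import requests

payload = {
    "inputs": "The sky is blue because",
    "parameters": {"temperature": 0.0001, "max_new_tokens": 20, "details": True},
}
resp = requests.post("http://localhost:8080/generate", json=payload)
tokens = resp.json()["details"]["tokens"]
print([t["logprob"] for t in tokens])  # contains None for the offending tokens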

OlivierDehaene commented 1 year ago

I'm new on text generation tasks but I want to lower the "creativity" of the model and stick to stable outputs

You should not use any temperature then and stick to greedy decoding.

But the issue is not about that

The issue seems to be about that. If you push the temperature too low, it will cause float overflow and result in nan logprobs.
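
A rough standalone sketch of that mechanism (not TGI's actual code path; float16 matches the model_dtype reported by /info above): dividing the logits by a tiny temperature overflows float16 to infinity, and the subsequent softmax then yields NaN logprobs.

# Sketch: why an extremely low temperature can turn logprobs into NaN.
import torch

logits = torch.tensor([10.0, 9.0, 8.0], dtype=torch.float16)
scaled = logits / 1e-4                         # overflows float16 (max ~65504) to inf
print(scaled)                                  # tensor([inf, inf, inf], dtype=torch.float16)
probs = torch.softmax(scaled.float(), dim=-1)  # softmax over infs -> nan
print(torch.log(probs))                        # tensor([nan, nan, nan]) -> NaN logprobs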

edwardzjl commented 1 year ago

Thank you for the advice, I will try greedy decoding.
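
Something along these lines (server URL assumed from the curl examples above; as far as I can tell, do_sample defaults to False in the client, so omitting temperature should give greedy decoding):

# Sketch: greedy decoding with the Python client; do_sample is left at its
# default and no temperature is passed.
from text_generation import Client

client = Client("http://localhost:8080")
response = client.generate("The sky is blue because", max_new_tokens=20)
print(response.generated_text)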

For this issue, I mean that if NaN logprobs mean something went wrong, maybe the server could return HTTP status 400 or similar, with an error message if possible, instead of 200. What do you think?