huggingface / optimum-nvidia

Apache License 2.0

llama.py with fp8 is broken (inference produces garbage results) #71

Open urimerhav opened 7 months ago

urimerhav commented 7 months ago

Hi!

I have a finetuned Llama 2 and followed example/llama.py. When I build the model in fp16, it works just fine and produces sane results. When I use either --fp8 or --fp8-cache, the results are garbage: the same Chinese character repeats throughout the completion.
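When comparing an fp16 build against an fp8 build, a quick automated check for this symptom can help with triage. Below is a minimal sketch, assuming the generated token ids are available as a plain Python list; the helper name `is_degenerate` and the 20-token threshold are hypothetical choices, not part of optimum-nvidia.

```python
def is_degenerate(token_ids, min_repeats=20):
    """Return True if the completion ends with one token repeated at least
    `min_repeats` times in a row (e.g. the same character over and over)."""
    if len(token_ids) < min_repeats:
        # Too short to judge; treat as not degenerate.
        return False
    tail = token_ids[-min_repeats:]
    # Degenerate if every token in the tail equals the first one.
    return all(t == tail[0] for t in tail)


if __name__ == "__main__":
    print(is_degenerate([42] * 30))         # same id repeated
    print(is_degenerate(list(range(30))))   # varied ids
```

Running the same prompt through both builds and flagging completions where `is_degenerate` is True makes it easy to tell whether a given fp8 configuration reproduces the problem.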

For completeness, here's the command I'm running: python build_llama.py --max-prompt-length 1548 --max-new-tokens 500 --fp8 --fp8-cache --max-batch-size 1 /var/datamodels/llama-7b /var/data/llama-7b-optimized

Model loading is done with:

self.model = AutoModelForCausalLM.from_pretrained(
    model_path,
    use_fp8=True,
)

Generation is done with:

generated, lengths = self.model.generate(
    input_ids=torch.tensor([input_tokens]).to(self.device),
    repetition_penalty=1.0,
    temperature=const.OpenSourceLlmInference.temperature,
    top_k=50,
    top_p=0.9,
    pad_token_id=self.tokenizer.pad_token_id,
    eos_token_id=self.tokenizer.eos_token_id,
    max_new_tokens=max_gen_tokens,
)

My machine is an H100, so the architecture is supported. I'm using the provided Docker image. Driver and CUDA versions:

NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.2

Thanks a lot for an awesome contribution! I hope we can figure this out, as low latency is mission-critical for us.

urimerhav commented 6 months ago

@kashif @srush @radames @co42

Sorry for spamming you all, but it's been a while. Is there anything else I should do to help triage this issue?

mfuntowicz commented 6 months ago

The float8 path will get a major rework in 0.1.0b4.

Also, calibration currently uses a default dataset (cnndaily) with 512 samples, which might not fit your use case or finetuned model. Did you try providing a custom dataset for calibration?
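To illustrate why the calibration data can matter, here is a minimal, generic sketch of per-tensor "amax" calibration for FP8 E4M3. It assumes the common convention scale = amax / 448, where 448 is the largest finite E4M3 value. The function names and sample values are hypothetical, mantissa rounding is ignored, and this is not necessarily how optimum-nvidia implements calibration. It shows one failure mode: activations larger than anything seen during calibration get clipped.

```python
# Generic sketch of per-tensor FP8 (E4M3) amax calibration.
# Assumption: scale = amax / 448 (448 = largest finite E4M3 magnitude).
# Mantissa rounding is ignored; only scaling and saturation are modeled.
# Not optimum-nvidia's actual implementation.

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def calibrate_scale(calibration_batches):
    """Per-tensor scale from the largest |activation| seen during calibration."""
    amax = max(abs(x) for batch in calibration_batches for x in batch)
    return amax / E4M3_MAX


def scale_and_saturate(values, scale):
    """Map values into the E4M3 range, clip at +/-448, then map back."""
    result = []
    for x in values:
        q = max(-E4M3_MAX, min(E4M3_MAX, x / scale))  # saturate out-of-range values
        result.append(q * scale)
    return result


if __name__ == "__main__":
    # Hypothetical calibration activations: magnitudes never exceed 2.0.
    scale = calibrate_scale([[0.5, -1.2], [2.0, -0.3]])
    # At inference time a larger activation (8.0) appears and is clipped to about 2.0.
    print(scale_and_saturate([1.0, 8.0, -8.0], scale))
```

If a finetuned model's activation ranges differ from those seen on the default calibration set, clipping like this is one plausible way quality can degrade, which is why calibrating on data closer to the finetuning distribution is worth trying.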

Thanks

urimerhav commented 6 months ago

So it seems the repo has changed since I wrote this. I was referring to a file called llama.py, which used to live under examples and optimized a saved Llama checkpoint. Now I can't even find it.

To further clarify the use case: we've finetuned our own Llama 2 on our own data and want to optimize it for serving. When we run the optimization and then generate, it works fine with fp16, but the outputs become trash (a repeated Chinese character) with fp8.

We didn't supply a custom dataset, but I can't imagine how a quantization step calibrated on any dataset with English inputs and outputs would end up producing nothing but Chinese characters.

With the repo code in its new state, I can't even tell how to optimize a Llama 2 model anymore. Did you deprecate support for that? If it's still available, the documentation doesn't hint at it and I can't find it in the code.