deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Support for FP8 quantization with TensorRT-LLM #3145

Open nathan-az opened 3 months ago

nathan-az commented 3 months ago

DJL does not support FP8 quantization (or at least has not documented support for it in the docs).

FP8 is currently TensorRT-LLM's recommended quantization technique, offering the smallest drop in model quality while still providing a good speedup.

It would be great to support this in DJL. It should not affect any APIs beyond adding an option (I expect something like option.quantization=fp8).

Any users seeking a speedup or lower memory footprint would benefit from this change.

Note: this contradicts the following passage from an AWS blog post, but I expect the post is inaccurate on this point:

"...as part of the latest LMI DLC release (0.25.0), enabling state-of-the-art optimizations like SmoothQuant, FP8, and continuous batching"

ydm-amazon commented 3 months ago

Hi Nathan, for FP8 quantization there are currently two options: SmoothQuant and AWQ.

For example, to enable FP8 SmoothQuant, you can add the following options:

option.quantize = smoothquant
option.smoothquant_alpha = 0.8
option.smoothquant_per_channel = true
option.smoothquant_per_token = true
option.dtype = fp8
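The option.smoothquant_alpha value is SmoothQuant's migration strength α. As a sketch of what it controls, following the formula from the SmoothQuant paper rather than any DJL or TensorRT-LLM code, the smoothing factor for input channel j is s_j = max|X_j|^α / max|W_j|^(1 - α). Activations are divided by s and weights are multiplied by s, so the layer output is unchanged while activation outliers shrink. The class and method names below are hypothetical, and the per-channel maxima are made-up illustrative numbers.

```java
import java.util.Arrays;

public class SmoothQuantScales {

    /**
     * Hypothetical helper: per-input-channel SmoothQuant smoothing factors
     * s_j = max|X_j|^alpha / max|W_j|^(1 - alpha).
     *
     * @param actAbsMax    max |activation| per input channel j (from calibration data)
     * @param weightAbsMax max |weight| per input channel j
     * @param alpha        migration strength (the role played by option.smoothquant_alpha)
     */
    public static double[] smoothingFactors(double[] actAbsMax, double[] weightAbsMax, double alpha) {
        if (actAbsMax.length != weightAbsMax.length) {
            throw new IllegalArgumentException("channel counts differ");
        }
        double[] s = new double[actAbsMax.length];
        for (int j = 0; j < s.length; j++) {
            // Clamp tiny maxima so "dead" channels do not produce zero or infinite factors.
            double a = Math.max(actAbsMax[j], 1e-5);
            double w = Math.max(weightAbsMax[j], 1e-5);
            s[j] = Math.pow(a, alpha) / Math.pow(w, 1.0 - alpha);
        }
        return s;
    }

    public static void main(String[] args) {
        // Illustrative numbers only: channel 0 has an activation outlier.
        double[] actAbsMax = {16.0, 1.0, 0.5};
        double[] weightAbsMax = {1.0, 1.0, 2.0};
        double[] s = smoothingFactors(actAbsMax, weightAbsMax, 0.8);
        System.out.println("s = " + Arrays.toString(s));
        for (int j = 0; j < s.length; j++) {
            // X' = X / s shrinks activation ranges; W' = W * s absorbs the same factor.
            System.out.printf("channel %d: act max %.3f -> %.3f, weight max %.3f -> %.3f%n",
                    j, actAbsMax[j], actAbsMax[j] / s[j], weightAbsMax[j], weightAbsMax[j] * s[j]);
        }
    }
}
```

With α = 0.8, most of channel 0's activation outlier is migrated into the weights (activation max 16 drops to about 1.74, weight max rises to about 9.19); at α = 0.5 the same channel would end up with equal maxima of 4.0 on both sides.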

nathan-az commented 3 months ago

Ah, thanks @ydm-amazon. I was aware of both, but I'm concerned about the difference in output quality, given the reported MMLU drop for SmoothQuant compared with "native" FP8. TGI recently added FP8 support but indicates that it only works on the Hopper architecture onward, presumably because Hopper is the first architecture that natively supports FP8 operations.

Couple of follow-up questions:

  1. Are there still plans to support the "native" FP8 described in TRT-LLM and recently added to TGI?
  2. Can you confirm that dtype should be set to fp8 when using SmoothQuant? The examples in the DJL docs seem to keep option.dtype = fp16 for both SmoothQuant and AWQ.
  3. I'm not sure about SmoothQuant, but I believe AWQ requires calibration, and there are two option.* parameters related to calibration. Which dataset is used as the calibration set for calibrated quantization methods when we use JIT engine compilation? If needed, is it possible to package a calibration dataset with the model files for JIT AWQ compilation?
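In its simplest form, calibration means running sample data through the model to record activation ranges, which then fix the quantization scales. A minimal sketch, assuming per-tensor max ("amax") calibration and the E4M3 FP8 format (largest finite value 448): the scale is amax / 448, and each value is quantized as fp8(x / scale) and dequantized by multiplying back by the scale. This is an illustrative simulation in plain doubles, not TensorRT-LLM's or DJL's calibration code; real toolchains may use other calibration methods and datasets, and all class and method names here are hypothetical.

```java
import java.util.Arrays;

public class Fp8PerTensorCalibration {

    /** Largest finite magnitude of the FP8 E4M3 format (no infinities). */
    static final double E4M3_MAX = 448.0;

    /** Per-tensor scale chosen so that the calibration amax maps onto E4M3_MAX. */
    public static double calibrateScale(double[][] calibrationBatches) {
        double amax = 0.0;
        for (double[] batch : calibrationBatches) {
            for (double x : batch) {
                amax = Math.max(amax, Math.abs(x));
            }
        }
        // Guard against an all-zero calibration set.
        return amax > 0.0 ? amax / E4M3_MAX : 1.0;
    }

    /** Rounds v to the nearest E4M3 value (ties to even), saturating at +/-448. */
    public static double roundToE4M3(double v) {
        if (v == 0.0 || Double.isNaN(v)) {
            return v;
        }
        double clamped = Math.max(-E4M3_MAX, Math.min(E4M3_MAX, v));
        double mag = Math.abs(clamped);
        // 3 mantissa bits: spacing within binade [2^e, 2^(e+1)) is 2^(e-3).
        // Below the smallest normal (2^-6) the subnormal spacing stays at 2^-9.
        int e = Math.max(Math.getExponent(mag), -6);
        double step = Math.scalb(1.0, e - 3);
        return Math.copySign(Math.rint(mag / step) * step, clamped);
    }

    /** Quantize-dequantize ("fake quantization"): x -> fp8(x / scale) * scale. */
    public static double[] fakeQuantize(double[] x, double scale) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = roundToE4M3(x[i] / scale) * scale;
        }
        return out;
    }

    public static void main(String[] args) {
        // Illustrative "calibration set": two small batches of activations.
        double[][] calibration = {
            {0.10, -0.75, 2.40},
            {-3.20, 0.05, 1.10},
        };
        double scale = calibrateScale(calibration);
        double[] activations = {0.30, -3.20, 5.00, 0.001};
        System.out.println("scale = " + scale);
        System.out.println(Arrays.toString(fakeQuantize(activations, scale)));
    }
}
```

Here 5.00 lies outside the calibrated range and saturates to roughly 3.2 (the calibration amax), which illustrates why the choice of calibration dataset affects output quality.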