IBM / text-generation-inference

IBM development fork of https://github.com/huggingface/text-generation-inference
Apache License 2.0

Performance Optimizations for TP-Aware GPTQ #67

Open cyang49 opened 6 months ago

cyang49 commented 6 months ago

This is a draft. Please do not merge.

Motivation

Current tgis-native provides GPTQ support for llama and starcoder models by utilizing the fast exllamav2 kernel (and also Marlin once #66 is merged). It works well in single-GPU deployment. However, for multi-GPU TP deployment, performance is known to be poor when deploying GPTQ checkpoints that require activation reordering (desc_act=True in the quantization config). This includes many publicly available GPTQ checkpoints.

The reason for the poor performance of these models is that the fast exllamav2 (or Marlin) kernels cannot be used in row-parallel layers: the weight-matrix row shuffling they require would introduce an extra all-gather to globally reorder the input activations of row-parallel layers under TP, and that all-gather can be prohibitively expensive. As a result, TGIS falls back to the much slower Triton matmul_248 kernel, which doesn't require shuffling, for these layers. This 50% CUDA / 50% Triton mix across the QuantLinear layers works, but it is too slow to be a practical solution. vLLM uses a similar approach, except that it falls back to an alternative GPTQ CUDA kernel instead of the Triton kernel; it still suffers from suboptimal performance.
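To make the constraint concrete, here is a toy single-process sketch (not TGIS code; `g_idx` is used in the usual GPTQ sense) showing how the activation-reordering permutation scatters the input features each rank's weight shard needs across all TP shards, which is what forces the all-gather:

```python
# Toy illustration of why desc_act=True breaks row parallelism: the GPTQ
# permutation makes each rank's weight shard consume input features that are
# owned by other ranks, so a global all-gather of activations would be needed.
import torch

torch.manual_seed(0)
in_features, world_size = 8, 2
shard = in_features // world_size

g_idx = torch.randperm(in_features)  # activation-reordering permutation (desc_act=True)

for rank in range(world_size):
    # features this rank owns under a plain row-parallel split of the activation
    owned = set(range(rank * shard, (rank + 1) * shard))
    # features the rank's (reordered) weight rows actually consume
    needed = set(g_idx[rank * shard:(rank + 1) * shard].tolist())
    missing = needed - owned
    print(f"rank {rank}: needs {len(missing)}/{shard} features held by other ranks")
```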

In this PR, we implement TP-aware GPTQ model-inference optimizations, combining the technique for the MLP layers introduced in the arXiv paper we published previously with a newer technique, masked matmul, for the attention layers.
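As a rough illustration of the MLP-side idea, the sketch below (a minimal toy with unquantized FP weights and assumed variable names, not the PR's actual code) folds the down-projection's activation reordering into the up-projection's output rows at weight-loading time, so each rank's local activations already arrive in the order its down-projection shard expects and the runtime all-gather disappears:

```python
# Minimal sketch of the TP-aware reordering idea for an MLP block (illustrative
# only, using full-precision weights instead of quantized ones).
import torch

torch.manual_seed(0)
hidden, inter, world_size = 4, 8, 2
shard = inter // world_size

x = torch.randn(1, hidden)
w_up = torch.randn(inter, hidden)    # column-parallel layer (rows split across ranks)
w_down = torch.randn(hidden, inter)  # row-parallel layer (columns split across ranks)
g_idx = torch.randperm(inter)        # down-proj activation reordering (desc_act=True)

ref = (x @ w_up.t()) @ w_down.t()    # unpermuted reference output

# Offline, at weight-loading time: fold the permutation into both layers
w_up_perm = w_up[g_idx, :]           # reorder up-proj output rows
w_down_perm = w_down[:, g_idx]       # store down-proj columns in the matching order

# Runtime, per TP rank: each rank works only on its local shard, no all-gather
partials = []
for rank in range(world_size):
    rows = slice(rank * shard, (rank + 1) * shard)
    h_local = x @ w_up_perm[rows, :].t()                  # local column-parallel output
    partials.append(h_local @ w_down_perm[:, rows].t())   # local row-parallel partial sum
y = sum(partials)                                         # normally an all-reduce

print(torch.allclose(ref, y, atol=1e-5))                  # True: same result, no all-gather
```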

Preliminary results using exllamav2 show that our techniques enable deploying Llama-70b GPTQ on L40S x2 at 24.67 tokens/s, a 30% throughput improvement over deploying the FP16 model on A100-80GB x2 (19 tokens/s), thus providing a cost-saving alternative for deploying Llama-70b. We expect even better results with Marlin.

Modifications

The code changes consist primarily of control-path adjustments that manipulate how weight tensors are loaded, plus environment variable flags to toggle the different modes.
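For illustration, a toggle of this kind might look like the following hypothetical sketch (the flag and function names here are assumptions, not the actual identifiers introduced by this PR):

```python
# Hypothetical sketch only: the environment variable name is illustrative.
import os

# e.g. select between the original loading path and the TP-aware GPTQ path
TP_AWARE_GPTQ = os.getenv("TP_AWARE_GPTQ", "0") == "1"  # assumed flag name

def load_quantized_weights(tp_aware: bool = TP_AWARE_GPTQ) -> None:
    if tp_aware:
        ...  # permute the preceding column-parallel weights at load time (see sketch above)
    else:
        ...  # keep the original layout; row-parallel layers fall back to the slower kernel
```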

Known issues:

Result

| Configuration | Prefill | Token latency | Throughput |
|---|---|---|---|
| FP16: L40S x4 | 1.96 s | 62.33 ms | 16.04 tokens/s |
| GPTQ, TP-aware: L40S x2 | 2.11 s | 40.55 ms | 24.67 tokens/s |
| GPTQ, original: L40S x2 | 3.48 s | 84.21 ms | 11.88 tokens/s |

We plan to update the results once the Marlin PR (#66) is merged.

Related Issues

Merge #66 to enable Marlin support.