NVIDIA / FasterTransformer

Transformer related optimization, including BERT, GPT
Apache License 2.0

INT8 Support for GPT models #265

Open bharatv007 opened 2 years ago

bharatv007 commented 2 years ago

I see that there is full int8 support (both weights and activations) for BERT, but it's not clear to me what is supported for GPT models (here). Ideally, if we can support the transformer block layers in int8, it will save a lot of memory and computation (with some overhead from quantization and dequantization). Can you elaborate on which features are supported now, and if and when we can expect full int8 support? I would also like to understand the hurdles one could face with an int8 implementation :). Thanks in advance.

byshiue commented 2 years ago

The cost of using QAT on the GPT model is high, so this feature is still under discussion.

bharatv007 commented 2 years ago

Thanks for getting back.

  1. Can you elaborate on why there is an overhead in the GPT case while the BERT numbers suggest gains?
  2. Even if there are overheads, it would definitely help reduce memory (we could use fewer GPUs). Can you also elaborate on the memory gains? In the DS paper (section 5.4), they mention speedups for decoding/generating 50 tokens. They also mention that they reduce the number of GPUs.

byshiue commented 2 years ago

  1. Because the model size of GPT is much larger than that of BERT.
  2. But they don't mention the QAT cost or the accuracy.

bharatv007 commented 2 years ago

> The cost of using QAT on the GPT model is high, so this feature is still under discussion.

By cost, do you mean the training cost and not the inference cost? Is there any inference cost (overhead) that may occur for int8 GPT inference?

byshiue commented 2 years ago

Yes, I mean the training and fine-tuning cost. As for additional inference cost, we can hide most of it if the workflow is fully INT8.

bharatv007 commented 2 years ago

Currently parallelGPT seems to support only weight quantization (only for the decoder, not the context? Correct me if I am wrong). Does "fully int8" refer to activation quantization too?

byshiue commented 2 years ago

Yes. If we want to run fully in INT8, we need to quantize the inputs/outputs of all kernels and GEMMs.
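
To make "quantize the inputs/outputs" concrete, here is a minimal sketch of symmetric per-tensor INT8 quantization in pure Python. The helper names are hypothetical (not FasterTransformer APIs) and the `scale = amax / 127` convention is an assumption; rounding half to even with saturation mirrors what an instruction like `cvt.rni.sat.s8.f32` does.

```python
def per_tensor_scale(values):
    """Scale that maps the largest magnitude in `values` onto 127 (assumed convention)."""
    amax = max(abs(v) for v in values)
    return amax / 127.0 if amax > 0 else 1.0


def quantize(values, scale):
    """Float -> INT8: round half to even (Python's round) and saturate to [-128, 127]."""
    return [max(-128, min(127, round(v / scale))) for v in values]


def dequantize(q_values, scale):
    """INT8 -> float: multiply each integer back by the scale."""
    return [q * scale for q in q_values]


x = [0.5, -1.27, 0.02, 1.0]
s = per_tensor_scale(x)   # roughly 0.01
q = quantize(x, s)        # [50, -127, 2, 100]
x_hat = dequantize(q, s)  # each element within s / 2 of x
```

Every kernel boundary in a fully INT8 pipeline needs a scale like `s` for its output, which is why the whole workflow has to be designed around quantized I/O.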

julian-q commented 2 years ago

We managed to use all int8 weights for GptContextAttentionLayer, DecoderSelfAttentionLayer, and FfnLayer using int8WeightPerChannelLdkMultiplicationLauncher. Since this function only supports m = 1 and m = 2, we used for-loops when the first dimension of the input was greater than 2.

Using this trick, we were able to reduce the maximum CUDA memory usage. We freed the unused full-precision weights after int8 quantization in examples/pytorch/gpt/utils/gpt.py. This will allow us to serve GPT models using fewer GPUs.
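
A rough way to estimate what freeing the full-precision weights buys. This sketch counts weights only (it ignores activations, the KV cache, and the small per-channel scale tensors) and uses a hypothetical 80 GB per-GPU budget; the function names are illustrative.

```python
import math


def weight_gb(num_params, bytes_per_param):
    """Weight storage in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def min_gpus_for_weights(num_params, bytes_per_param, gpu_gb):
    """Lower bound on the GPU count needed just to hold the weights."""
    return math.ceil(weight_gb(num_params, bytes_per_param) / gpu_gb)


params = 175e9                   # a 175B-parameter GPT
fp16_gb = weight_gb(params, 2)   # 2 bytes per FP16 weight
int8_gb = weight_gb(params, 1)   # 1 byte per INT8 weight
print(fp16_gb, int8_gb)          # 350.0 175.0
print(min_gpus_for_weights(params, 2, 80), min_gpus_for_weights(params, 1, 80))  # 5 3
```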

However, using a for-loop for the int8 weight matmuls is very slow and we are trying to parallelize it with a kernel that supports arbitrary m. Do you have any tips for redesigning int8WeightPerChannelLdkMultiplication so it works with m > 2?
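
For reference, the math a weight-only, per-channel kernel has to reproduce does not depend on m: each output column j has one scale, so the scale can be pulled out of the k-sum and applied once. The pure-Python sketch below uses hypothetical names and an assumed layout (activations m x k, weights k x n); it is not the CUDA kernel. It shows that looping over rows one at a time gives exactly the same result as processing all m rows together, which is the property a batched kernel would exploit.

```python
def quantize_weight_per_channel(w):
    """w is k x n (list of rows); returns INT8 weights and one scale per output column."""
    k, n = len(w), len(w[0])
    scales = []
    for j in range(n):
        amax = max(abs(w[r][j]) for r in range(k))
        scales.append(amax / 127.0 if amax > 0 else 1.0)
    w_q = [[max(-128, min(127, round(w[r][j] / scales[j]))) for j in range(n)]
           for r in range(k)]
    return w_q, scales


def weight_only_matmul(x, w_q, scales):
    """y = x @ dequant(w_q) for any number of rows m in x (x is m x k)."""
    k, n = len(w_q), len(w_q[0])
    y = []
    for row in x:
        # The per-column scale is constant over r, so apply it once after the k-sum.
        y.append([sum(row[r] * w_q[r][j] for r in range(k)) * scales[j] for j in range(n)])
    return y


x = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]   # m = 3 rows
w = [[0.1, -0.4], [0.25, 0.8]]              # k = 2, n = 2
w_q, scales = quantize_weight_per_channel(w)
batched = weight_only_matmul(x, w_q, scales)
looped = [weight_only_matmul([row], w_q, scales)[0] for row in x]
```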

byshiue commented 2 years ago

The function is only faster than FP16 when m <= 2. Otherwise, the INT8 kernel is slower than FP16 because of the casting and scaling overhead.

Tecmus commented 2 years ago

> 1. Because the model size of GPT is much larger than that of BERT.
> 2. But they don't mention the QAT cost or the accuracy.

How long does QAT take compared with general GPT fine-tuning? Are there any comparison metrics?

byshiue commented 2 years ago

We don't have such a study. If you have seen any studies on QAT for GPT, you are welcome to share them with us.

ImanHosseini commented 2 years ago

GLM-130B: An Open Bilingual Pre-trained Model has new results on applying quantization to GPT models. Do you have any suggestions (or, say, a blueprint) on how to adapt the current GPT-J implementation to support INT8? (I have been looking at bert & bert_int8 to see how you did it for BERT, but the GPT-J code is more complicated.)

ImanHosseini commented 2 years ago

Some questions regarding INT8 support:

Q1. Almost all templated classes (e.g. https://github.com/NVIDIA/FasterTransformer/blob/6fddeac5f59ce4df380002aa945da57a0c8e878c/src/fastertransformer/models/gpt/GptDecoderLayerWeight.cc#L201) only support float or half:

template struct GptDecoderLayerWeight<float>;
template struct GptDecoderLayerWeight<half>;

Assuming one wants to implement INT8 for GPT-J, would the first step be to implement "int8_t" versions of these components (e.g. like the int8 version of layernorm here: https://github.com/NVIDIA/FasterTransformer/blob/bc214067d603bf95beb8af5d1ce5960e96ba7244/src/fastertransformer/kernels/layernorm_int8_kernels.h)? Can you break down (roughly) the steps to implement INT8 for GPT-J?

Q2. I see that in the case of 'bertINT8Example' the inputs are either fp16 or fp32, and later there is this: https://github.com/NVIDIA/FasterTransformer/blob/9f249ed83545912e21665123a97b9aaab3c3fa44/src/fastertransformer/models/bert_int8/BertLayerINT8.cc#L211, which invokes quantization that boils down to "cvt.rni.sat.s8.f32". My question is: Coral Edge TPUs, for example, take a model that already uses int8 everywhere (rather than fp16/32). If we are going to use int8, why not save the weights as int8 to begin with, rather than storing fp32 and then doing a conversion? Even here: https://github.com/NVIDIA/FasterTransformer/blob/bc214067d603bf95beb8af5d1ce5960e96ba7244/src/fastertransformer/models/bert_int8/BertLayerINT8Weight.h#L34, these 'deviceMalloc' calls allocate memory using sizeof(T), which is fp16/32 and not int8_t.

Q3. Is the comment here truncated? It ends with "because the", the what? https://github.com/NVIDIA/FasterTransformer/blob/bc214067d603bf95beb8af5d1ce5960e96ba7244/src/fastertransformer/layers/attention_layers_int8/FusedAttentionLayerINT8.h#L31

byshiue commented 2 years ago

Unlike FP16, INT8 often requires mixed precision. For example, to run NCCL allReduce, we may need to pass fp16/bf16 instead of INT8, so the output of the previous GEMM may be set to fp16/bf16.

Also, different INT8 workflows lead to different accuracy, and we have many options. For example, should we use per-tensor, per-channel, or per-group scales to quantize the weights/inputs? Do we need to keep some parts in fp16/bf16 to maintain accuracy? It highly depends on your workflow.
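
A small pure-Python sketch of how scale granularity affects weight accuracy. Each output channel is stored as one list: a per-tensor scale is shared by everything, a per-channel scale by one list, and a per-group scale by a contiguous slice of `group_size` elements within a channel. The data, the `group_size=2` default, and the error metric (error relative to the channel's own max magnitude) are illustrative assumptions.

```python
def scale_for(vals):
    amax = max(abs(v) for v in vals)
    return amax / 127.0 if amax > 0 else 1.0


def fake_quant(v, s):
    """Quantize one value to INT8 with scale s, then dequantize it back to float."""
    return max(-128, min(127, round(v / s))) * s


def fake_quant_matrix(channels, granularity, group_size=2):
    """Quantize-dequantize `channels` (one list per output channel)."""
    if granularity == "per_tensor":
        s = scale_for([v for ch in channels for v in ch])
        return [[fake_quant(v, s) for v in ch] for ch in channels]
    if granularity not in ("per_channel", "per_group"):
        raise ValueError(f"unknown granularity: {granularity}")
    out = []
    for ch in channels:
        step = len(ch) if granularity == "per_channel" else group_size
        new_ch = []
        for i in range(0, len(ch), step):
            block = ch[i:i + step]
            s = scale_for(block)  # one scale per channel or per group
            new_ch.extend(fake_quant(v, s) for v in block)
        out.append(new_ch)
    return out


def worst_relative_error(channels, approx):
    """Largest element error, measured relative to the max |value| of its own channel."""
    return max(
        max(abs(a - b) for a, b in zip(ch, q)) / max(abs(v) for v in ch)
        for ch, q in zip(channels, approx)
    )


channels = [[0.01, -0.02, 0.015, 0.005],   # small-magnitude channel
            [5.0, -4.0, 0.3, 0.1]]         # large-magnitude channel
for g in ("per_tensor", "per_channel", "per_group"):
    print(g, worst_relative_error(channels, fake_quant_matrix(channels, g)))
```

With a single per-tensor scale, the small channel is mostly rounded to zero; per-channel and per-group scales keep every channel within half a quantization step of its own range, at the cost of storing more scales.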

If you want to quantize all inputs/weights/outputs to INT8 per-tensor (like BERT), then you should quantize the weights from fp16/bf16 to int8 at the beginning and quantize the inputs in some kernel. The GEMM computes in INT32, so when you run it you need to provide the scales to let it rescale the INT32 results and quantize them again to INT8 as output.
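
The per-tensor INT8 GEMM described above can be sketched in pure Python: INT8 products are summed exactly (an INT32 accumulator on the GPU), and a single factor `alpha = s_a * s_b / s_out` rescales the accumulator into an INT8 output. The function and argument names are hypothetical, and real kernels may fold the scales differently.

```python
def int8_gemm(a_q, b_q, s_a, s_b, s_out):
    """Per-tensor INT8 GEMM: INT8 in, exact integer accumulation, INT8 out."""
    k, n = len(b_q), len(b_q[0])
    alpha = s_a * s_b / s_out  # one rescale factor for the whole output
    out = []
    for row in a_q:
        out_row = []
        for j in range(n):
            # Python ints never overflow; a GPU kernel would accumulate in INT32 here.
            acc = sum(row[r] * b_q[r][j] for r in range(k))
            out_row.append(max(-128, min(127, round(acc * alpha))))
        out.append(out_row)
    return out


# Real values: a = [[0.5, -0.25]], b = [[0.2, 0.4], [0.8, -0.6]]
a_q = [[50, -25]]            # scale s_a = 0.01
b_q = [[10, 20], [40, -30]]  # scale s_b = 0.02
y_q = int8_gemm(a_q, b_q, 0.01, 0.02, 0.05)  # output scale s_out = 0.05
# a @ b = [[-0.1, 0.35]], which is [[-2, 7]] in units of 0.05
```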

For custom CUDA kernels like layernorm, we prefer INT8 I/O most of the time. So the kernel receives INT8 inputs, dequantizes them to FP32 inside the kernel, runs layernorm to get FP32 results, and quantizes the results back to INT8.
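
A minimal pure-Python sketch of that INT8-in/INT8-out layernorm pattern, with the statistics computed in float. The function name, the input/output scales, and `eps` are assumptions for illustration.

```python
import math


def layernorm_int8_io(x_q, s_in, gamma, beta, s_out, eps=1e-5):
    """INT8 in, INT8 out; mean/variance/normalization computed in float."""
    x = [q * s_in for q in x_q]                      # dequantize INT8 -> float
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    y = [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]
    return [max(-128, min(127, round(v / s_out))) for v in y]  # quantize float -> INT8


y_q = layernorm_int8_io([10, 20, 30, 40], s_in=0.1,
                        gamma=[1.0] * 4, beta=[0.0] * 4, s_out=1.5 / 127)
```

Only INT8 tensors cross the kernel boundary, so memory traffic stays at one byte per element even though the arithmetic is done in float.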

You can refer to the INT8 workflow of BERT at https://github.com/NVIDIA/FasterTransformer/blob/main/docs/bert_guide.md#model-architecture.

Answer for Q1: you can follow this idea to support something like GPTJINT8, which means that you use float for intermediate results or as the accumulation data type, while most inputs/outputs are INT8.

Answer for Q2: We don't assume the input is already INT8, so we need to quantize it at the beginning. Most of the time, the weights are still fp32/fp16 because we cannot really run an INT8 workflow on the training side. In most frameworks, we only run fake quantization to simulate the INT8 results and obtain the scales.
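
Fake quantization keeps tensors in floating point but snaps them onto the INT8 grid, which lets a training framework measure the effect of INT8 and record the scales. A minimal sketch using simple max calibration (an assumption; calibrators may also use other statistics such as percentiles); the names are hypothetical.

```python
def calibrate_scale(batches):
    """Max calibration: per-tensor scale from the largest |x| seen in calibration data."""
    amax = max(abs(v) for batch in batches for v in batch)
    return amax / 127.0 if amax > 0 else 1.0


def fake_quantize(values, scale):
    """Round onto the INT8 grid and saturate, but return floats (quantize + dequantize)."""
    return [max(-128, min(127, round(v / scale))) * scale for v in values]


calib = [[0.3, -1.0], [0.8, 0.25]]   # activations observed during calibration
s = calibrate_scale(calib)           # 1.0 / 127
out = fake_quantize([0.2, 2.0], s)   # 2.0 lies outside the calibrated range and saturates
```

The recorded `s` is what an inference engine would later need to run the same tensor in real INT8.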

tangbinh commented 1 year ago

@byshiue Would you mind confirming whether weight-only quantization works with the GPT 175B model without mixed precision? I have been able to get reasonable outputs using an OPT 175B checkpoint, but it requires an extra copy of weights in FP16 (i.e. freeing some FP16 weights in this fashion results in garbage outputs). Interestingly, for smaller OPT models (e.g. 125M), it works perfectly fine without that FP16 copy of some weights. Do you have any idea what might cause the issue?

This might be unrelated to this issue, but after the large commit that you recently introduced, I can't make the OPT 175B checkpoint work even without quantization, as it keeps producing negative output indices. Strangely, smaller OPT models (e.g. 125M) work just fine again. Have you tested the new commit with any large checkpoint (e.g. GPT, OPT)?

Lastly, do you expect the new SmoothQuant integration (e.g. int8_mode = 2) to work with GPT / OPT? Are there more instructions besides the brief paragraph in the doc (e.g. the PyTorch example still doesn't allow int8_mode = 2)? Thank you for your help in advance.

byshiue commented 1 year ago

In older code (v5.1), weight-only quantization is used only in generation. When computing the context cache, FT still uses FP16. So, freeing the FP16 weights leads to garbage results.

In the latest code, weight-only quantization can be used in all cases, and we don't need to keep the FP16 weights.

> I can't make the OPT 175B checkpoint work even without quantization as it keeps producing negative output indices.

Can you provide the steps to reproduce and the issue you observed? We have verified correctness and did not find any issue, but it may be related to the environment and hardware.

For SmoothQuant, we have now verified it on OPT. It requires activation scales; you can provide them or run calibration with the converter, as demonstrated in the documentation.
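
For context on what those activation scales feed into: in the SmoothQuant paper, per-input-channel smoothing factors `s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)` migrate activation outliers into the weights while leaving the product unchanged, since `(X diag(s)^-1)(diag(s) W) = X W`. The sketch below follows that formulation with `alpha = 0.5` and toy data; it is not FasterTransformer's converter code, and all names are hypothetical.

```python
def matmul(a, b):
    return [[sum(a[i][r] * b[r][j] for r in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def column_amax(m):
    return [max(abs(row[j]) for row in m) for j in range(len(m[0]))]


def row_amax(m):
    return [max(abs(v) for v in row) for row in m]


def smoothing_factors(act_amax, weight_amax, alpha=0.5):
    """One factor per input channel j, following the SmoothQuant formulation."""
    return [(a ** alpha) / (w ** (1 - alpha)) for a, w in zip(act_amax, weight_amax)]


def smooth(x, w, s):
    """Divide activation column j by s[j]; multiply weight row j (input channel j) by s[j]."""
    x_s = [[v / s[j] for j, v in enumerate(row)] for row in x]
    w_s = [[v * s[j] for v in row] for j, row in enumerate(w)]
    return x_s, w_s


x = [[10.0, 0.1], [-8.0, 0.2]]   # input channel 0 has an outlier
w = [[0.1, 0.2], [0.5, -0.4]]    # k x n: row j belongs to input channel j
s = smoothing_factors(column_amax(x), row_amax(w))
x_s, w_s = smooth(x, w, s)
print(column_amax(x), column_amax(x_s))  # channel ranges before / after smoothing
```

The per-channel activation ranges become far more uniform after smoothing, which is what makes per-tensor activation quantization viable.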

erichan1 commented 1 year ago

Hi! Interested in int8 as well. To clarify for int8_mode=1 (weight-only int8), during inference are you casting int8 weights back to fp16, doing an fp16 linear, then deleting the fp16 weight?

For smoothquant, is the calibration done with calibration.py in the smoothquant repo?

For both of these, I'd appreciate a pointer to the int8 kernels in FasterTransformer.

byshiue commented 1 year ago

Weight-only int8 dequantizes the weights from int8 to fp16 inside the GEMM.

For SmoothQuant, that's correct; you can ask for more details in the SmoothQuant repo.

You can find the wrappers of these GEMMs and follow them to their implementations.

erichan1 commented 1 year ago

Just to confirm: for GPT models, is int8_mode=1 the one shown here https://github.com/NVIDIA/FasterTransformer/blob/main/docs/images/workflow-of-int8-inference.png? And is int8_mode=2 switched to SmoothQuant in GPT models instead of the one shown in the picture?

As a follow-up, can we expect int8_mode=1 to be accurate for GPT models, given that LLM.int8() showed that there are large column-wise outliers in the activations? It seems like int8_mode=1 does not consider outliers.

chenho74 commented 1 year ago

(not the author, but) I think the graph was for int8 BERT. The two int8 modes in BERT are defined differently than those in GPT. I haven't seen a graph for int8 GPT in the repo. +1 for adding a graph for GPT, please :D, especially with the new 5.3 features.

erichan1 commented 1 year ago

Thanks @chenho74. I understand int8_mode=2 (SmoothQuant) since there's a paper for it. But it's unclear to me exactly what's going on for int8_mode=1. Is it simply that we quantize the weights down to int8 beforehand, then dequantize them back to fp16 at inference time and perform an fp16 GEMM? However, in the tests we're running, int8_mode=1 is faster than fp16, so this can't be the whole story. So now I'm unsure whether we're running an int8 GEMM (which shouldn't work if we don't deal with outliers, as shown in the LLM.int8() paper). Would love a clarification from @byshiue.

byshiue commented 1 year ago

int8_mode = 1 is weight-only quantization. As @erichan1 says, we quantize the weights to int8 during weight conversion. During inference, we load the int8 weights, dequantize them to fp16 inside the GEMM, and use the fp16 Tensor Cores.

erichan1 commented 1 year ago

Ok, so to clarify, it's the same as SmoothQuant's W8A16 linear here? Makes sense that this would retain accuracy. The speed is what's surprising. We're seeing speedups over fp16, which I thought would be impossible, since we're still doing an fp16 GEMM and need extra ops to dequantize. It's probably because int8_mode=1 shrinks the weights, which lets us reduce model parallelism. But for comparison, the unoptimized W8A16 kernel in torch-int is something like 5-10x slower than a regular fp16 linear. Awesome that the FT implementation is this optimized.

Is it possible to get this kernel as a separate PyTorch op? Or at least a pointer to where it is... having a hard time finding where it lives in the FT source. I think this would be a great generic kernel for LLM developers.

byshiue commented 1 year ago

It is the same as W8A16.

When the batch size is small, GEMM is often memory bound, so reducing the weight size helps. However, when the batch size is large enough, GEMM becomes compute bound and the additional casting from INT8 to FP16 adds overhead, so the performance would be worse than pure FP16.
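
A back-of-the-envelope roofline check of this explanation. For a GEMM `y[m x n] = x[m x k] W[k x n]`, the work is `2mkn` FLOPs, while the bytes moved are dominated by the weights when m is small. The sketch assumes FP16 activations in and out, counts each tensor once (ideal caching), ignores the INT8-to-FP16 conversion cost, and uses a hypothetical ridge point of 150 FLOP/byte; the 12288 hidden size is only an example.

```python
def arithmetic_intensity(m, k, n, weight_bytes):
    """FLOPs per byte of memory traffic for y = x @ W with FP16 activations."""
    flops = 2 * m * k * n
    traffic = k * n * weight_bytes + 2 * (m * k + m * n)  # weights + FP16 input + FP16 output
    return flops / traffic


def bound(m, k, n, weight_bytes, ridge=150.0):
    """Classify the GEMM as memory- or compute-bound against an assumed ridge point."""
    return "memory" if arithmetic_intensity(m, k, n, weight_bytes) < ridge else "compute"


hidden = 12288
for m in (1, 2, 16, 1024):
    fp16 = arithmetic_intensity(m, hidden, hidden, 2)
    int8 = arithmetic_intensity(m, hidden, hidden, 1)
    print(m, round(fp16, 1), round(int8, 1), bound(m, hidden, hidden, 2))
```

Under these assumptions, at small m halving the weight bytes roughly doubles the intensity of a memory-bound GEMM, so it runs faster; at large m both variants are compute bound, the weight traffic no longer dominates, and the dequantization work is pure overhead.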

The kernels are implemented with CUTLASS. We don't have plans to encapsulate these kernels as a PyTorch op for now.

songkq commented 1 year ago

> int8_mode = 1 is weight-only quantization. As @erichan1 says, we quantize the weights to int8 during weight conversion. During inference, we load the int8 weights, dequantize them to fp16 inside the GEMM, and use the fp16 Tensor Cores.

@byshiue @erichan1 Hi, I'm confused about how to quantize the weights to INT8 during weight conversion for INT8 weight-only quantization. As shown in the examples, the converters only support converting weights to FP32/FP16.
https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/utils/nemo_ckpt_convert.py https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/utils/huggingface_gpt_convert.py

So do I just need to convert the weights to FP16 and set int8_mode=1, and FasterTransformer will load and quantize the weights from file automatically? (Correct me if I am wrong.)