Open · gsujankumar opened this issue 2 years ago
Hi @gsujankumar
Thanks for pointing out this issue. I will look into this and send a fix soon. Best, Reza
Hey Reza,
Thanks for looking into this. Can you provide me a rough ETA on the fix?
Sorry for the delay, @gsujankumar. I will try to make it work by the end of this week or early next week.
@RezaYazdaniAminabadi Did you find what is causing the issues?
I tried to implement INT8 inference with a MoQ-trained BERT model. I noticed that INT8 inference only seems to work with deepspeed>=0.4.0,<0.4.3 and with transformers==5.5.2. With versions beyond these I am running into either errors or accuracy issues.
Hi @gsujankumar,
Sorry for the delay on this. I was busy with some internal projects. Yes, you are right that MoQ was mainly targeted at the older version of transformers. The error you are seeing above is coming from the GeMM, and can also be related to the quantization happening before this operation. I will let you know once this is fixed.
Thanks, Reza
Hey @RezaYazdaniAminabadi, I was able to resolve the issue by casting the FP16 inputs to FP32 in the compute_attention method in deepspeed/ops/inference/transformer_inference.py as follows:
Around line 154, add the following lines
if config.q_int8:
    # Upcast the fused QKV output to FP32 before the attention math,
    # since the FP32 kernels otherwise receive FP16 data.
    qkv_out = qkv_out.float()
and cast the outputs of the function to FP16 before returning by adding the following lines:
if config.q_int8:
    # Cast the results back to FP16 so the rest of the FP16 pipeline is unchanged.
    context_layer = context_layer.half()
    key_layer = key_layer.half()
    value_layer = value_layer.half()
return context_layer, key_layer, value_layer
It looks like the method was using FP32 kernels but still feeding them FP16 data. Computing with FP32 inputs resolved the NaNs, which in turn resolved the CUDA errors.
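For illustration, here is a minimal standalone sketch of the same upcast-compute-downcast pattern in plain PyTorch (not the DeepSpeed kernels; the shapes and the attention-score computation are only an example):

import torch

def attention_scores_fp32(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Illustrative only: mirrors the cast-to-float / cast-back-to-half
    # pattern of the fix above, not the actual compute_attention code.
    orig_dtype = q.dtype
    # Upcast so the matmul/softmax run in FP32 and do not produce NaNs.
    q32, k32 = q.float(), k.float()
    scores = torch.softmax(q32 @ k32.transpose(-1, -2) / q32.shape[-1] ** 0.5, dim=-1)
    # Downcast so callers keep seeing the original (FP16) dtype.
    return scores.to(orig_dtype)

if __name__ == "__main__":
    q = torch.randn(1, 12, 128, 64, dtype=torch.half)
    k = torch.randn(1, 12, 128, 64, dtype=torch.half)
    print(attention_scores_fp32(q, k).dtype)  # torch.float16

Guarding the casts with config.q_int8, as in the fix above, keeps the FP16 fast path untouched for non-quantized inference.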
I am noticing good accuracy with quantization_setting=1, but with any higher grouping the accuracy drops. This is counterintuitive. Is there something that might be amiss?
Hi @gsujankumar,
I am happy you could resolve the issue. Can you please make a PR and add the fix? For a larger number of groups, you are right that you should get better accuracy, not worse. Are you using our quantizer kernels when quantizing this model? If so, I can double-check whether there is any accuracy issue with that. Thanks, Reza
Sure, I will create a PR soon.
Yes, we are using the quantizer kernels from DeepSpeed. Can you check if there are any issues with groups?
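As one way to sanity-check the group-wise behavior, below is a minimal reference implementation of symmetric per-group INT8 quantize/dequantize in plain PyTorch (an illustrative sketch, not the DeepSpeed quantizer kernel; the group counts and tensor sizes are assumptions). Comparing its reconstruction error against the kernel output for the same group counts would show whether the kernel's grouping behaves as expected:

import torch

def quantize_dequantize_groupwise(w: torch.Tensor, num_groups: int) -> torch.Tensor:
    # Reference symmetric per-group INT8 quantize/dequantize of a 2-D tensor.
    # Each group gets its own scale, so more groups should give lower error.
    rows, cols = w.shape
    assert (rows * cols) % num_groups == 0, "elements must divide evenly into groups"
    grouped = w.reshape(num_groups, -1)
    scale = grouped.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(grouped / scale), -127, 127)
    return (q * scale).reshape(rows, cols)

if __name__ == "__main__":
    w = torch.randn(1024, 1024)
    for g in (1, 8, 64):
        err = (w - quantize_dequantize_groupwise(w, g)).abs().mean().item()
        print(f"groups={g:3d}  mean abs error={err:.6f}")

With this reference, the mean error should shrink as the group count grows, which is also the behavior one would expect from the kernels.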
Describe the bug
I am trying to get started with implementing INT8 inference on DeepSpeed, but I am running into RuntimeError: CUDA error: an illegal memory access was encountered.

To Reproduce
Code:
I am interested in implementing INT8 inference with GPT2-style models; the code I am running is the following:
I am running this with
I noticed a few bugs blocking INT8 inference, and I made the following changes to the source code:
- deepspeed/runtime/weight_quantizer.py, as is_mlp was not defined
- deepspeed/runtime/weight_quantizer.py
- deepspeed/ops/inference/transformer_inference.py
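For orientation, here is a minimal sketch of the kind of INT8 setup described above, assuming the deepspeed.init_inference API of that era (the model name, mp_size, and quantization setting are illustrative assumptions, not the exact reproduction script):

import torch
import deepspeed
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Illustrative GPT2 + DeepSpeed kernel injection with INT8 weights.
# Run under the deepspeed launcher, e.g.: deepspeed gpt_example.py
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

model = deepspeed.init_inference(
    model,
    mp_size=1,
    dtype=torch.int8,
    replace_method="auto",
    replace_with_kernel_inject=True,
    quantization_setting=1,  # quantization grouping, as discussed above
)

inputs = tokenizer("DeepSpeed INT8 inference test", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape, logits.dtype)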
Expected behavior
Output meaningful logits
ds_report output
Outputs
While the code runs error-free with dtype=torch.float and dtype=torch.half, I am running into errors with dtype=torch.int8. Running
CUDA_VISIBLE_DEVICES=1 CUDA_LAUNCH_BLOCKING=1 deepspeed gpt_example.py
results in the following output:

Launcher context
Using the deepspeed launcher