cjg91 / trans-fat

An FPGA Accelerator for Transformer Inference

Quantizing Bert #3

Open cjg91 opened 2 years ago

cjg91 commented 2 years ago

This issue tracks progress and understanding of the quantization problem. We will select a downstream task to fine-tune BERT and perform quantization-aware training. Then, we will take these quantized weights and implement the forward pass using pure Tensor functions to deobfuscate operations.

Tasks:

Questions:

The fine-tuning task that I selected: MRPC. MLPerf uses SQuAD 1.1 as the downstream task, and they also use BERT-Large, so we can't compare apples to apples.

The quantization technique I use: TBD

Resources:

- PyTorch quantization intro
- PyTorch quantization docs
- Quantization-supported operations
- BERT dynamic quantization tutorial
- TensorFlow quantized transformer paper

cjg91 commented 2 years ago
| Model | Task | Quantization | Accuracy | F1 score |
| --- | --- | --- | --- | --- |
| textattack/roberta-base-MRPC | GLUE MRPC | None | 87.2 | |
| textattack/roberta-base-MRPC | GLUE MRPC | Dynamic Quantization | 86.6 | |
| textattack/roberta-base-MRPC | GLUE MRPC | My Dynamic Quantization | 86.8 | |
| kssteven/ibert-roberta-base | GLUE MRPC | INT8 static (fine-tuned) | 0.3162 | 0.0 |
cjg91 commented 2 years ago

Pytorch Quantization Update

There is trouble with static quantization. By default, the qconfig uses the torch.qint8 dtype, but nn.Embedding only supports torch.quint8 quantization. To work around this, I am setting the embedding's qconfig to use the torch.quint8 dtype, as in the sketch below.
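Roughly, the per-module override looks like this (a sketch only; `model.roberta.embeddings` is a placeholder path for whatever the embedding submodule is actually called):

```python
import torch
from torch.quantization import QConfig, MinMaxObserver, get_default_qconfig

def set_quantization_configs(model):
    # Global static-quantization config (qint8 weights by default).
    model.qconfig = get_default_qconfig("fbgemm")
    # Override just the embedding submodule so it is observed as quint8,
    # since nn.Embedding does not support qint8 quantization.
    # NOTE: `model.roberta.embeddings` is a placeholder attribute path.
    model.roberta.embeddings.qconfig = QConfig(
        activation=MinMaxObserver.with_args(dtype=torch.quint8),
        weight=MinMaxObserver.with_args(dtype=torch.quint8),
    )
    return model
```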

Another possible solution is to leave the embedder unquantized and use QuantStub in the encoder layer to quantize activations (sketched below). This would be fine for our purposes, but the difficulty is that pure matmul is not implemented for quantized tensors; you would have to use the quantized functional linear op to do matmul. Requires further investigation.
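For reference, the QuantStub idea would look something like this (a sketch only; the attention matmuls inside the encoder would still hit the missing quantized-matmul problem):

```python
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class QuantEncoderWrapper(nn.Module):
    """Sketch: keep the embedder in fp32 and only quantize activations
    around the encoder stack."""
    def __init__(self, encoder):
        super().__init__()
        self.quant = QuantStub()      # fp32 -> quantized at the encoder boundary
        self.encoder = encoder
        self.dequant = DeQuantStub()  # back to fp32 on the way out

    def forward(self, hidden_states):
        return self.dequant(self.encoder(self.quant(hidden_states)))
```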

It's unclear what combination of dynamic and static quantization we will use in the end. The most important thing is to get the encoder layer fully quantized and to verify that we understand its operations by mimicking its computation with numpy functions (a rough sketch of what that looks like for a quantized linear layer is below).
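Illustrative sketch (not the repo's code) of mimicking a statically-quantized linear layer with numpy, given int8 weights plus the scales/zero-points that PyTorch's observers would produce:

```python
import numpy as np

def quant_linear_ref(x_q, x_scale, x_zp, w_q, w_scale, bias, y_scale, y_zp):
    # int32 accumulation of zero-point-adjusted int8 operands
    # (weights assumed symmetric, i.e. weight zero-point is 0)
    acc = (x_q.astype(np.int32) - x_zp) @ w_q.astype(np.int32).T
    y = acc * (x_scale * w_scale) + bias           # dequantize + fp32 bias
    y_q = np.round(y / y_scale) + y_zp             # requantize the output activation
    return np.clip(y_q, 0, 255).astype(np.uint8)   # quint8 activations
```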

Even when implementing a quantized region around the encoder layers, there is trouble with PyTorch static quantization. Since there is no supported quantized matmul operation, I tried to use torch.nn.quantized.functional.linear to implement matmul, but it expects a quint8 input and qint8 weights. In my case both the input and the weights were qint8, and I'm not sure how to fix that.
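One untested idea, recorded just so the shape of a possible workaround isn't lost: dequantize the qint8 activation and re-quantize it as quint8 so it matches what the functional linear expects (all arguments here are placeholders):

```python
import torch
import torch.nn.quantized.functional as qF

def linear_from_qint8(x_q, w_q, out_scale, out_zero_point):
    """Untested sketch: convert a qint8 activation to quint8 before qF.linear."""
    x_fp32 = x_q.dequantize()
    # qint8 zero-point 0 maps to quint8 zero-point 128 (assumption for symmetric input)
    x_quint8 = torch.quantize_per_tensor(
        x_fp32, scale=x_q.q_scale(), zero_point=128, dtype=torch.quint8
    )
    return qF.linear(x_quint8, w_q, scale=out_scale, zero_point=out_zero_point)
```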

Other Options

  1. Use MLPerf reference implementation that does INT8 quantization

  2. I-BERT: someone already made an integer-quantized BERT. I didn't use it at first because I didn't want to fine-tune on a task, but it's clear how easy that is now.

  3. QDQBERT is another Hugging Face implementation that uses techniques similar to the MLPerf reference implementation. This option is starting to look like the best one.

cjg91 commented 2 years ago

I-BERT Notes

I-BERT requires finetuning on a downstream task, and the default is MRPC.

Got terrible accuracy on MRPC (0.33 accuracy) when doing quantization-aware fine-tuning for 1 epoch. It's possible that the Hugging Face implementation has some bug that the original fairseq implementation does not.

Looking at the code, I-BERT does not use percentile-based quantization; it uses min and max instead. This could be a reason for the terrible performance.

cjg91 commented 2 years ago

Custom Quantization Notes

In bert_sw/dynamic_quant_ops.py, a dynamically-quantized matmul and linear layer are implemented. Modifying an MRPC-trained RoBERTa (bert_sw/dynamic_quant_roberta.py) to use these operations gives 90.2% accuracy, down from 91.2%. For comparison, PyTorch dynamic quantization on the same model gives 90.4% accuracy (note that torch.quantization.quantize_dynamic does not quantize matmuls).
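A minimal sketch of what a dynamically-quantized matmul looks like, in the spirit of bert_sw/dynamic_quant_ops.py (the real implementation may differ; per-tensor scales are computed from min/max at run time):

```python
import torch

def dynamic_quant_matmul(A, B):
    """Sketch: symmetric per-tensor int8 quantization of both operands,
    integer-style accumulation, then rescale back to fp32."""
    def quantize(x):
        scale = x.abs().max() / 127.0                                   # map max |x| to 127
        q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
        return q, scale

    qA, sA = quantize(A)
    qB, sB = quantize(B)
    # Emulate the integer matmul with a float matmul here; on hardware this
    # would be an int8 x int8 multiply with int32 accumulation.
    acc = torch.matmul(qA.to(torch.float32), qB.to(torch.float32))
    return acc * (sA * sB)
```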

The reason I was getting poor accuracy in my earlier dynamic implementation is that I used the 99.9th and 0.1st percentiles to find the quantization bounds, which caused accuracy to fall to 74%. Plain min and max perform much better.
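For reference, the two ways of picking bounds compared here (sketch; `x` is whatever activation or weight tensor is being quantized):

```python
import torch

def quant_bounds(x, use_percentiles=False):
    """Min/max bounds (what works) vs. 0.1/99.9 percentile bounds (what hurt accuracy)."""
    x = x.flatten().float()
    if use_percentiles:
        lo, hi = torch.quantile(x, 0.001), torch.quantile(x, 0.999)
    else:
        lo, hi = x.min(), x.max()
    return lo, hi
```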

cjg91 commented 2 years ago

MLPerf Quantization Notes

Using NVIDIA's pytorch-quantization toolkit with QAT to get a quantized model and to figure out how to implement the quantized ops.
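A sketch of the toolkit's documented setup pattern (model construction and the QAT loop itself are omitted; `build_model` is a placeholder):

```python
from pytorch_quantization import quant_modules

# Swap nn.Linear (and friends) for fake-quantized versions before the model
# is built, so QAT inserts quantizers automatically.
quant_modules.initialize()

model = build_model()  # placeholder: load/fine-tune the BERT-base model as usual
# ... calibrate, run quantization-aware fine-tuning, then export weights and scales ...
```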

Can't do it on viz because wrong CUDA version. Can't do it on hyperion because wrong GPU for PyTorch.

Intrepid can run QAT to produce a BERT-base-sized quantized model; that is running now. The output should be a folder of .npz files containing the layer weights. Will these contain the scaling factors as well?
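Once the run finishes, a quick way to answer the scaling-factor question is to list what each exported archive actually contains (file path here is hypothetical):

```python
import numpy as np

ckpt = np.load("qat_output/encoder_layer_0.npz")  # placeholder path
for name in ckpt.files:
    print(name, ckpt[name].shape, ckpt[name].dtype)
```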

Also, it would be useful to figure out how to replicate MLPerf inference with this. Ideally, use the ONNX output or something similar to generate ground truth for checking that our quantized implementation is functionally correct.
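Something like the following harness could work for that check (a sketch under assumptions: the ONNX input names, file path, and `our_forward` are all placeholders):

```python
import numpy as np
import onnxruntime as ort

def compare_against_onnx(onnx_path, input_ids, attention_mask, our_forward):
    """Run the exported ONNX model as a reference and compare it with our own
    quantized forward pass on the same inputs."""
    sess = ort.InferenceSession(onnx_path)
    ref = sess.run(None, {"input_ids": input_ids,
                          "attention_mask": attention_mask})[0]
    ours = our_forward(input_ids, attention_mask)
    print("max abs diff vs ONNX:", np.abs(ref - ours).max())
```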