facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Does Dynamic Quantization speed up Fairseq's Transformer? #1901

Closed kalyangvs closed 2 years ago

kalyangvs commented 4 years ago

What is your question?

As per this tutorial in torch, quantize_dynamic speeds up models (though it only supports Linear and LSTM layers as of now). But I am unable to observe a speed-up for fairseq's Transformer, whereas Huggingface's BERT gets a good speed-up.

Is there any suggestion or am I missing something?

Code

The model is passed to quantize_dynamic after make_generation_fast in fairseq_cli/generate.py:

    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

What have you tried?

The example model in the tutorial gets around a 2x speed-up irrespective of the number of threads.

What's your environment?

myleott commented 4 years ago

What kind of model are you trying to quantize? It sounds like you're trying to speed up translation models, since you're modifying generate.py. A lot of the runtime here is due to overhead from beam search, since we're generating one token at a time.

BERT doesn't have this issue, since it gets the whole input all at once and produces the entire output all at once. Thus you should generally see bigger speedup with encoder-only models (e.g., BERT) than models with decoders (e.g., translation, language modeling).

Can you try with the fairseq RoBERTa implementation and see if you get a speedup?
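
A rough, untested sketch of what trying this might look like — the hub model name, the timing loop, and the in-place conversion are all illustrative, not fairseq-provided code:

    import time
    import torch

    # Load fairseq RoBERTa via torch.hub (downloads roberta.base on first use).
    roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
    roberta.eval()
    tokens = roberta.encode('Hello world, this is a test sentence for timing.')

    def avg_ms(n=50):
        # Average latency of one forward pass, in milliseconds.
        start = time.time()
        with torch.no_grad():
            for _ in range(n):
                roberta.extract_features(tokens)
        return 1000 * (time.time() - start) / n

    fp32_ms = avg_ms()

    # Convert the Linear layers to dynamically quantized int8 modules, in place,
    # so the hub interface keeps pointing at the same (now quantized) module.
    torch.quantization.quantize_dynamic(
        roberta, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
    )

    int8_ms = avg_ms()
    print(f'fp32: {fp32_ms:.1f} ms/call, int8: {int8_ms:.1f} ms/call')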

Edwardlzy commented 4 years ago

Just want to follow up on this issue. In my understanding, generate.py by default loads an ensemble (a list of transformers), and you need to replace each individual transformer in place to evaluate the quantized model. However, when I ran inference on a quantized model, I hit an error in the forward method of fairseq/modules/multihead_attention.py:

    torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
    TypeError: expected Tensor as element 0 in argument 0, but got method.

After checking the documentation for quantize_dynamic(), it seems the weights and biases of each trainable layer are wrapped and can only be accessed via LinearPackedParams._weight_bias(). So I changed multihead_attention.py to access the weights and biases properly. Then I got this error in torch.nn.functional:

    output = input.matmul(weight.t())
    RuntimeError: Could not run 'aten::mm' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::mm' is only available for these backends: [CUDATensorId, SparseCPUTensorId, VariableTensorId, CPUTensorId, SparseCUDATensorId].

Does this mean quantization is not yet supported in Fairseq? Is there any workaround to run quantization simulation for Transformer models? Thanks in advance!
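
A minimal sketch of the kind of accessor change this requires, assuming the only difference is that the dynamically quantized Linear exposes weight and bias as methods rather than attributes (the helper name is made up):

    import torch

    def linear_bias(proj):
        # On a float torch.nn.Linear, .bias is a tensor (or None); on a dynamically
        # quantized Linear, bias and weight are methods that unpack the underlying
        # LinearPackedParams, which is what triggers the
        # "expected Tensor ... but got method" error above.
        return proj.bias() if callable(proj.bias) else proj.bias

    # e.g. torch.cat((linear_bias(self.q_proj), linear_bias(self.k_proj), linear_bias(self.v_proj)))

Even with that change, the later aten::mm error is unsurprising: on this code path the attention passes the raw (now quantized) weight tensors to the functional implementation instead of calling the quantized submodules themselves (e.g. self.q_proj(query)), so the float matmul kernel ends up being asked to run on a quantized tensor.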

kalyangvs commented 4 years ago

Though we generate one token at a time, the time for each token generated on the decoder side should still decrease (as it would for a non-autoregressive seq2seq model), which should reduce the overall time, since Transformers are made up of self-attention and linear layers and we are quantizing the linear layers.

@Edwardlzy I got the same error when I ran quantize_dynamic on RoBERTa.

But for a transformer model from here, when run on GPU I get RuntimeError: Could not run 'aten::quantize_per_tensor' with arguments from the 'CUDATensorId' backend. 'aten::quantize_per_tensor' is only available for these backends: [CPUTensorId, VariableTensorId]. On CPU it runs fine without any error.
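
For reference, a minimal sketch of the CPU-only constraint behind that error; model here is a placeholder for the loaded transformer, not actual fairseq code:

    import torch

    # Dynamic quantization currently ships CPU kernels only, so both the conversion
    # and inference have to stay on CPU (keep --cpu / don't call .cuda()).
    model = model.cpu().eval()
    quantized = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    # Inputs passed to `quantized` must also be CPU tensors.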

Edwardlzy commented 4 years ago

@gvskalyan Thanks for the update. Do you mind sharing your code changes in detail? For example, in generate.py, did you make sure that your quantized model replaced the original one? And did you make any changes to other related files like multihead_attention.py and functional.py? I noticed there was a new commit to Fairseq yesterday related to quantization. I tried the updated Fairseq and am still stuck on the same error: torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), TypeError: expected Tensor as element 0 in argument 0, but got method.

It would be great if @myleott could let us know the right way to run quantization simulations in Fairseq. Thanks!

kalyangvs commented 4 years ago

I have just added model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 ) here,

And I think the recent commit relating to quantization was intended to be compatible with, and target, the torch 1.5.0 release.

Since it is listed as experimental in the torch tutorials, maybe @driazati or Ning Dong would know better.

Edwardlzy commented 4 years ago

@gvskalyan Thank you for the clarification. This is what I tried initially. When I ran evaluation on the "quantized" 8-bit model, I got exactly the same BLEU score as with the float32 model. This suggests that the model you quantized did not replace the one in models: the quantized model was created in a new local variable named model and was dropped after that iteration. If you replace the for model in models loop with for i in range(len(models)) and quantize models[i], I believe you will get the error I encountered.
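
A minimal sketch of the in-place replacement described above, assuming models is the ensemble list loaded in generate.py:

    import torch

    # Replace each ensemble member inside the list itself, so that the sequence
    # generator built from `models` actually runs the int8 modules.
    for i in range(len(models)):
        models[i] = torch.quantization.quantize_dynamic(
            models[i], {torch.nn.Linear}, dtype=torch.qint8
        )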

Quantization is definitely not supported in the current Fairseq version. I have only been able to run inference after rewriting functions, mainly in multihead_attention.py in Fairseq and functional.py in torch.nn. It would be great if someone from Fairseq could add support for quantization; I would be happy to help.

kalyangvs commented 4 years ago

Yes, I can verify that it gives the same error you mentioned. Sorry for overlooking that iteration part.

XiaobingSuper commented 4 years ago

For the transformer.wmt14.en-fr model, the gemm (linear) ops only make up about 41% of the whole workload, so even with dynamic quantization it cannot get as high a speed-up as BERT (where gemm makes up 70%~80% and the int8 path gets about a 1.8x speed-up).

Pre-trained model: https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2
Dataset: https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2
Run: fairseq-generate data-bin/wmt14.en-fr.joined-dict.newstest2014 --path data-bin/wmt14.en-fr.joined-dict.transformer/model.pt --batch-size 128 --beam 5 --remove-bpe --cpu --quiet

profiler result:

Name                           Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls 
-----------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
aten::mm                       34.30%           83.842s          34.38%           84.041s          1.583ms          53092
aten::index_select             23.71%           57.961s          23.83%           58.261s          2.043ms          28516
aten::bmm                      7.09%            17.332s          7.18%            17.556s          682.373us        25728
aten::_log_softmax             6.38%            15.589s          6.38%            15.602s          14.718ms         1060
aten::_cat                     5.70%            13.941s          5.78%            14.125s          1.110ms          12722
aten::topk                     5.40%            13.191s          10.86%           26.549s          6.333ms          4192
aten::nonzero                  4.44%            10.842s          4.44%            10.842s          974.253us        11129
aten::add                      1.24%            3.021s           1.28%            3.127s           72.651us         43036
aten::_softmax                 0.95%            2.316s           0.97%            2.366s           183.941us        12864
aten::copy_                    0.86%            2.091s           0.86%            2.092s           2.981us          701615
aten::add_                     0.84%            2.042s           0.84%            2.042s           38.176us         53488
aten::eq                       0.63%            1.547s           1.61%            3.940s           7.203us          546994
aten::uniform_                 0.62%            1.510s           0.62%            1.510s           5.698ms          265
aten::lt                       0.61%            1.484s           1.97%            4.805s           8.880us          541132
aten::div_                     0.56%            1.374s           0.58%            1.414s           534.152us        2648
aten::to                       0.53%            1.292s           1.04%            2.537s           5.341us          474982
aten::ne                       0.52%            1.270s           1.05%            2.566s           156.480us        16396
aten::select                   0.50%            1.214s           0.63%            1.542s           2.003us          769526
aten::view                     0.50%            1.214s           0.50%            1.214s           2.360us          514185
aten::_local_scalar_dense      0.48%            1.181s           0.48%            1.181s           0.971us          1216111
aten::empty                    0.38%            940.632ms        0.38%            940.632ms        1.033us          910608
aten::native_layer_norm        0.37%            907.179ms        0.41%            1.005s           51.896us         19368
aten::item                     0.36%            881.843ms        0.84%            2.063s           1.696us          1216111
aten::threshold                0.30%            744.394ms        0.32%            780.607ms        120.020us        6504
aten::normal_                  0.28%            682.135ms        0.28%            682.135ms        341.067ms        2
aten::as_strided               0.24%            576.000ms        0.24%            576.000ms        0.556us          1035151
aten::resize_                  0.21%            502.591ms        0.21%            502.591ms        0.767us          655405
aten::matmul                   0.19%            470.515ms        34.83%           85.133s          1.604ms          53092
aten::empty_strided            0.15%            374.766ms        0.15%            374.766ms        0.901us          416104
aten::masked_fill_             0.15%            373.153ms        0.15%            373.153ms        57.373us         6504
aten::is_nonzero               0.14%            351.842ms        0.54%            1.325s           2.444us          542072
aten::transpose                0.14%            331.905ms        0.19%            464.422ms        4.293us          108184
aten::t                        0.12%            291.702ms        0.19%            471.598ms        8.883us          53092
aten::mul_                     0.11%            265.367ms        0.19%            457.671ms        35.980us         12720
aten::slice                    0.10%            245.073ms        0.13%            318.686ms        3.248us          98110
aten::unbind                   0.09%            220.028ms        0.23%            559.496ms        46.578us         12012
aten::_unsafe_view             0.07%            179.628ms        0.14%            348.094ms        6.556us          53092
aten::index                    0.07%            159.936ms        0.09%            214.620ms        22.426us         9570
aten::mul                      0.06%            149.981ms        0.07%            167.732ms        61.848us         2712
aten::is_floating_point        0.04%            98.035ms         0.04%            98.035ms         0.282us          347256
aten::floor_divide             0.04%            96.478ms         0.05%            118.777ms        6.249us          19008
aten::softmax                  0.04%            88.753ms         1.01%            2.477s           192.542us        12864
aten::layer_norm               0.04%            88.604ms         0.45%            1.101s           56.828us         19368
aten::unsqueeze                0.04%            85.932ms         0.04%            99.115ms         6.851us          14468
aten::contiguous               0.03%            79.695ms         0.08%            191.545ms        2.001us          95724
aten::_index_put_impl_         0.03%            76.442ms         4.48%            10.951s          1.984ms          5519
aten::cat                      0.03%            74.303ms         5.81%            14.200s          1.116ms          12722
aten::sum                      0.03%            72.356ms         0.05%            114.064ms        35.512us         3212
aten::stride                   0.03%            71.803ms         0.03%            71.803ms         0.312us          230025
aten::expand                   0.02%            60.596ms         0.03%            72.795ms         4.571us          15924
aten::masked_fill              0.02%            56.244ms         0.33%            800.293ms        123.046us        6504
aten::narrow                   0.02%            49.892ms         0.05%            121.161ms        4.868us          24890
aten::clone                    0.02%            44.050ms         0.13%            328.826ms        34.588us         9507
aten::any                      0.02%            43.576ms         0.02%            57.998ms         11.374us         5099
aten::gather                   0.02%            41.513ms         0.02%            50.882ms         12.278us         4144
aten::relu                     0.01%            31.306ms         0.33%            811.913ms        124.833us        6504
aten::type_as                  0.01%            29.910ms         0.02%            50.471ms         3.302us          15284
aten::sort                     0.01%            28.267ms         0.03%            64.748ms         21.561us         3003
aten::fill_                    0.01%            28.139ms         0.01%            28.139ms         2.796us          10063
aten::bitwise_not              0.01%            26.371ms         0.02%            55.272ms         6.669us          8288
aten::empty_like               0.01%            24.928ms         0.02%            49.340ms         6.284us          7852
aten::reshape                  0.01%            23.228ms         0.02%            39.985ms         2.186us          18293
aten::arange                   0.01%            22.469ms         0.02%            45.343ms         5.857us          7742
aten::masked_select            0.01%            21.942ms         0.03%            67.980ms         34.264us         1984
aten::fmod                     0.01%            21.843ms         0.02%            47.664ms         22.483us         2120
aten::index_put_               0.01%            19.417ms         4.49%            10.970s          1.988ms          5519
aten::bitwise_and              0.01%            16.534ms         0.01%            32.741ms         7.722us          4240
aten::gt                       0.01%            15.082ms         0.01%            25.092ms         11.574us         2168
aten::sub_                     0.01%            12.618ms         0.01%            23.736ms         22.392us         1060
aten::log_softmax              0.00%            11.895ms         6.39%            15.615s          14.731ms         1060
aten::ge                       0.00%            11.715ms         0.01%            18.908ms         9.126us          2072
aten::mean                     0.00%            10.056ms         0.04%            105.505ms        99.533us         1060
aten::embedding                0.00%            7.471ms          0.05%            128.512ms        118.553us        1084
aten::all                      0.00%            7.244ms          0.00%            10.563ms         10.196us         1036
aten::sub                      0.00%            6.810ms          0.00%            8.088ms          8.754us          924
aten::_shape_as_tensor         0.00%            6.142ms          0.00%            10.460ms         9.650us          1084
aten::div                      0.00%            5.818ms          0.00%            8.726ms          60.601us         144
detach_                        0.00%            5.345ms          0.00%            5.345ms          1.007us          5310
aten::detach_                  0.00%            4.205ms          0.00%            9.550ms          1.798us          5310
aten::__and__                  0.00%            3.779ms          0.01%            24.847ms         11.720us         2120
aten::sin                      0.00%            1.880ms          0.00%            3.744ms          936.035us        4
aten::ones                     0.00%            1.296ms          0.00%            2.932ms          7.405us          396
aten::zeros                    0.00%            1.014ms          0.00%            3.510ms          16.248us         216
aten::cos                      0.00%            674.841us        0.00%            1.343ms          335.792us        4
aten::zero_                    0.00%            590.303us        0.00%            2.296ms          9.334us          246
aten::repeat                   0.00%            430.778us        0.00%            1.497ms          62.367us         24
aten::_cumsum                  0.00%            404.558us        0.00%            544.045us        22.669us         24
detach                         0.00%            370.571us        0.00%            370.571us        1.333us          278
aten::cumsum                   0.00%            271.373us        0.00%            1.202ms          50.074us         24
aten::dropout                  0.00%            265.072us        0.00%            265.072us        1.841us          144
aten::unfold                   0.00%            241.803us        0.00%            361.766us        7.537us          48
aten::set_                     0.00%            226.966us        0.00%            226.966us        1.220us          186
aten::detach                   0.00%            220.686us        0.00%            591.257us        2.127us          278
aten::exp                      0.00%            215.900us        0.00%            418.325us        104.581us        4
aten::alias                    0.00%            34.403us         0.00%            34.403us         1.433us          24
aten::expand_as                0.00%            28.514us         0.00%            72.917us         3.038us          24
-----------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Self CPU time total: 244.442s
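
For reference, a minimal sketch of how a per-operator table like the one above can be collected; run_generation is a placeholder for the generation loop being measured:

    import torch

    with torch.autograd.profiler.profile() as prof:
        run_generation()  # placeholder: the fairseq-generate inner loop

    # Aggregate per-operator statistics, sorted like the table above.
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=60))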

nicolabertoldi commented 3 years ago

@gvskalyan a while ago you mentioned you were able to run inference with a dynamically quantized transformer after changing some parts of the fairseq code. I tried something similar, but I failed.

Could you please share your modified code?

Thanks in advance.

kalyangvs commented 3 years ago

@nicolabertoldi Sorry, I was not able to run a dynamically quantized transformer with the then-current codebase. I have not tried the latest version or any alternative ways of doing dynamic quantization for fairseq's Transformer.

The Huggingface repo seems to have the fairseq model supported; please try using onnxruntime with the model there, you might get an improvement.
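
For what it's worth, a minimal sketch of ONNX Runtime's own dynamic quantization applied to an already-exported model; the file names are placeholders, and exporting the translation model to ONNX is a separate step not shown here:

    from onnxruntime.quantization import quantize_dynamic, QuantType

    # Rewrites the weights of the MatMul/Gemm nodes to int8; activations are
    # quantized dynamically at runtime, much like torch.quantization.quantize_dynamic.
    quantize_dynamic(
        "translation_model.onnx",       # placeholder: exported ONNX graph
        "translation_model.int8.onnx",  # quantized copy written here
        weight_type=QuantType.QInt8,
    )

The quantized file can then be loaded with onnxruntime.InferenceSession as usual.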

nicolabertoldi commented 3 years ago

@gvskalyan

To your knowledge, there is no way to use quantization with the Transformer using the latest versions of PyTorch and fairseq. Correct?

Have you found alternative ways to do that?

stale[bot] commented 3 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

HamidShojanazeri commented 3 years ago

I have just added model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 ) here,

  • I checked the number of parameters using p.numel(); ideally the number of parameters should not differ, since the weights should simply be quantized. Maybe numel cannot count the packed parameters, which would be an indication the model was quantized (see the sketch at the end of this comment).
  • I printed the model: instead of Linear it shows DynamicQuantizedLinear everywhere.
  • How does the correct variable get replaced so that only quantized_model is used? I have gone through the code, and when the --cpu flag is removed (to run on GPU) it gives the above-mentioned error.

And I think the recent commit relating to quantization was intended to be compatible with, and target, the torch 1.5.0 release.

Since it is listed as experimental in the torch tutorials, maybe @driazati or Ning Dong would know better.

It is very late to comment on this issue, but just for reference, the quantization workflows don't support CUDA at this point.
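
Regarding the first bullet above, a minimal sketch (the helper name is made up) of checking the conversion by module type rather than by parameter count:

    import torch

    def summarize_quantization(model):
        # Module types are a more direct check than p.numel(): after quantize_dynamic,
        # Linear layers are replaced by torch.nn.quantized.dynamic.Linear (printed as
        # DynamicQuantizedLinear), and their packed int8 weights no longer appear in
        # model.parameters() at all.
        n_float = sum(isinstance(m, torch.nn.Linear) for m in model.modules())
        n_quant = sum(isinstance(m, torch.nn.quantized.dynamic.Linear)
                      for m in model.modules())
        print(f"float Linear: {n_float}, DynamicQuantizedLinear: {n_quant}")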

Oxi84 commented 3 years ago

So quantization will not work on BERT either if it is used on CUDA?

It does not work for me either, but I only tried to use it on the model directly:

import torch
from fairseq.models.transformer import TransformerModel

ru2en = TransformerModel.from_pretrained(
    '/root/models/deen/wmt19.de-en.joined-dict.ensemble/',
    checkpoint_file='model1.pt',
    #checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
    data_name_or_path='/root/models/deen/wmt19.de-en.joined-dict.ensemble/',
    bpe='fastbpe',
    bpe_codes='/root/models/deen/wmt19.de-en.joined-dict.ensemble/bpecodes'
)

ru2en = torch.quantization.quantize_dynamic(ru2en, {torch.nn.Linear}, dtype=torch.qint8)

File "stringsource", line 2, in fastBPE.fastBPE.__reduce_cython__
TypeError: self.c_obj cannot be converted to a Python object for pickling
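
An untested workaround sketch for that pickling error: quantize_dynamic deep-copies the whole hub interface by default, and the fastBPE c_obj cannot be deep-copied, so quantizing in place (at the cost of losing the float copy) sidesteps the copy entirely:

    import torch

    # inplace=True skips the copy.deepcopy() that trips over the fastBPE object.
    ru2en = torch.quantization.quantize_dynamic(
        ru2en, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
    )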

HamidShojanazeri commented 3 years ago

@Oxi84 that's right, PyTorch quantization schemes are not working for CUDA at this point; it's under development. For now the best options in PyTorch land are using model.half() or using mixed precision on CUDA.
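
A minimal sketch of those two GPU-side options; model and batch_on_gpu are placeholders:

    import torch

    model = model.cuda().eval()

    # Option 1: fp16 weights and activations throughout.
    fp16_model = model.half()

    # Option 2: keep fp32 weights and run the forward pass under autocast.
    with torch.no_grad(), torch.cuda.amp.autocast():
        output = model(batch_on_gpu)  # placeholder input, already on the GPU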

madelagua commented 2 years ago

What's the best way to apply dynamic quantization as of today? I guess we can either convert the models to ONNX or apply the modifications @Edwardlzy mentioned.

Are there any plans to support dynamic quantization or model conversion to ONNX? Also @Edwardlzy, you mentioned you managed to run inference; could you please share the modifications you made?

Thanks in advance!

stale[bot] commented 2 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

stale[bot] commented 2 years ago

Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!