NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

BMM much slower with mixed precision #546

Open tsdalton opened 4 years ago

tsdalton commented 4 years ago

I have a Seq2Seq network with attention, and when training with Apex/O1 optimization I notice that mixed precision is more than 3x slower than FP32. BMM appears to be the culprit. Any ideas why this is happening?

PyTorch 1.3, CUDA 10.1, cuDNN 7.6.4, NVIDIA Tesla K80
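For reference, here is a minimal CPU-only sketch of how the comparison below can be reproduced with `torch.autograd.profiler`; the tensor shapes are taken from the dominant bmm row in the mixed-precision profile, and on a GPU you would move the tensors to CUDA and construct the profiler with `use_cuda=True` to capture kernel times as well.

```python
import torch

# Shapes from the dominant bmm in the mixed-precision profile below:
# attention weights (400, 100, 1) batch-multiplied with values (400, 1, 64).
a = torch.randn(400, 100, 1)
b = torch.randn(400, 1, 64)

# Profile the batched matmul; on a CUDA build, move a/b to the GPU and
# pass use_cuda=True to the profiler to see CUDA totals as well.
with torch.autograd.profiler.profile() as prof:
    out = torch.bmm(a, b)

print(out.shape)  # torch.Size([400, 100, 64])
print(prof.key_averages().table(sort_by="cpu_time_total"))
```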

FP32

--------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------  
Name                                          Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  Input Shapes                         
--------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------  
AddmmBackward                                 0.21%            30.304ms         1.70%            246.426ms        247.913us        17.36%           4.517s           4.545ms          994              [[400, 11192]]                       
mm                                            1.15%            167.055ms        1.15%            167.055ms        168.063us        9.73%            2.533s           2.548ms          994              [[11192, 400], [400, 128]]           
mm                                            0.20%            28.216ms         0.20%            28.216ms         28.386us         7.55%            1.965s           1.977ms          994              [[400, 11192], [11192, 128]]         
addmm                                         0.43%            61.957ms         0.43%            61.957ms         62.331us         4.45%            1.158s           1.165ms          994              [[11192], [400, 128], [128, 11192],  
CudnnRnnBackward                              0.36%            51.642ms         11.41%           1.651s           1.661ms          3.24%            843.150ms        848.240us        994              [[400, 64], [], []]                  
_cudnn_rnn_backward                           11.06%           1.599s           11.06%           1.599s           1.609ms          3.22%            838.949ms        844.013us        994              [[400, 64], [99840], [4, 400, 64],   
LogSoftmaxBackward                            0.00%            130.929us        0.04%            5.522ms          690.293us        3.01%            783.753ms        97.969ms         8                [[400, 11192, 100]]                  
_log_softmax_backward_data                    0.04%            5.391ms          0.04%            5.391ms          673.927us        3.01%            783.718ms        97.965ms         8                [[400, 11192, 100], [400, 11192, 10  
log_softmax                                   0.00%            98.397us         0.02%            2.882ms          360.198us        2.86%            744.305ms        93.038ms         8                [[400, 11192, 100]]                  
_log_softmax                                  0.02%            2.783ms          0.02%            2.783ms          347.898us        2.86%            744.271ms        93.034ms         8                [[400, 11192, 100]]                  
gru                                           0.31%            45.028ms         3.23%            466.632ms        469.449us        2.28%            594.649ms        598.238us        994              [[400, 64], [1], [4, 400, 64]]       
_cudnn_rnn                                    2.80%            405.360ms        2.80%            405.360ms        407.807us        2.08%            541.182ms        544.448us        994              [[400, 64], [99840], [4, 400, 64], 
--------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------  
Self CPU time total: 14.464s
CUDA time total: 26.026s

Mixed Precision (O1)

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------  
Name                                 Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  Input Shapes                         
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------  
BmmBackward                          0.05%            22.959ms         15.90%           7.577s           9.471ms          22.03%           19.553s          24.441ms         800              [[400, 100, 1]]                      
bmm                                  24.62%           11.735s          24.62%           11.735s          7.334ms          13.23%           11.746s          7.341ms          1600             [[400, 100, 1], [400, 1, 64]]        
bmm                                  11.36%           5.415s           11.36%           5.415s           6.768ms          10.98%           9.745s           12.181ms         800              [[400, 64, 100], [400, 100, 1]]      
BmmBackward                          0.24%            115.837ms        45.89%           21.867s          22.590ms         6.96%            6.181s           6.386ms          968              [[400, 1, 64]]                       
AddmmBackward                        0.34%            161.317ms        0.78%            372.953ms        385.282us        4.11%            3.646s           3.767ms          968              [[400, 11192]]                       
bmm                                  17.33%           8.259s           17.33%           8.259s           10.324ms         3.55%            3.149s           3.936ms          800              [[400, 1, 64], [400, 64, 100]]       
mm                                   0.13%            61.531ms         0.13%            61.531ms         63.565us         3.37%            2.994s           3.093ms          968              [[400, 11192], [11192, 128]]         
bmm                                  3.14%            1.495s           3.14%            1.495s           1.868ms          2.63%            2.334s           2.917ms          800              [[400, 100, 64], [400, 64, 1]]       
bmm                                  2.96%            1.411s           2.96%            1.411s           1.763ms          2.58%            2.288s           2.860ms          800              [[400, 1, 100], [400, 100, 64]]      
BmmBackward                          0.01%            2.488ms          1.76%            840.919ms        9.778ms          2.46%            2.187s           25.428ms         86               [[400, 86, 1]]                       
BmmBackward                          0.00%            2.380ms          1.63%            777.998ms        9.488ms          2.27%            2.012s           24.542ms         82               [[400, 82, 1]]                       
bmm                                  2.74%            1.304s           2.74%            1.304s           7.579ms          1.47%            1.307s           7.600ms          172              [[400, 86, 1], [400, 1, 64]]         
bmm                                  2.53%            1.203s           2.53%            1.203s           7.338ms          1.36%            1.210s           7.376ms          164              [[400, 82, 1], [400, 1, 64]]         
bmm                                  1.27%            603.404ms        1.27%            603.404ms        7.016ms          1.23%            1.089s           12.660ms         86               [[400, 64, 86], [400, 86, 1]]        
addmm                                0.14%            64.449ms         0.14%            64.449ms         66.580us         1.21%            1.078s           1.114ms          968              [[11192], [400, 128], [128, 11192],  
CudnnRnnBackward                     0.11%            51.731ms         2.33%            1.109s           1.145ms          1.14%            1.014s           1.047ms          968              [[400, 64], [], []]                  
_cudnn_rnn_backward                  2.22%            1.057s           2.22%            1.057s           1.092ms          1.14%            1.009s           1.043ms          968              [[400, 64], [99840], [4, 400, 64],   
bmm                                  1.17%            556.992ms        1.17%            556.992ms        6.793ms          1.13%            1.002s           12.214ms         82               [[400, 64, 82], [400, 82, 1]]        
LogSoftmaxBackward                   0.00%            155.561us        0.01%            2.996ms          374.467us        0.88%            781.095ms        97.637ms         8                [[400, 11192, 100]]                  
_log_softmax_backward_data           0.01%            2.840ms          0.01%            2.840ms          355.022us        0.88%            781.061ms        97.633ms         8                [[400, 11192, 100], [400, 11192, 10  
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------  
Self CPU time total: 47.656s
CUDA time total: 88.775s

mcarilli commented 4 years ago

What version of CUDA are you using? For CUDA versions earlier than 9.1, bmm was known to be slow in FP16; the current version of apex should keep it in FP32 for those versions, though.
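One way to test that diagnosis is to pin bmm to FP32 explicitly. A hedged sketch: `amp.register_float_function` is apex's documented hook for this and must be called before `amp.initialize`, while the manual `.float()` cast is a framework-agnostic fallback; the shapes match the bmm calls in the profile above.

```python
import torch

# Option 1 (apex): register torch.bmm as an FP32 op *before* calling
# amp.initialize, so O1 patching casts its inputs up instead of down.
# Guarded, since apex may not be installed in every environment.
try:
    from apex import amp
    amp.register_float_function(torch, 'bmm')
    have_apex = True
except ImportError:
    have_apex = False

# Option 2 (manual): cast the attention operands up to FP32 just for
# the bmm. Half-precision inputs stand in for what amp's O1 patching
# would have produced for this [400, 100, 1] x [400, 1, 64] call.
a = torch.randn(400, 100, 1, dtype=torch.float16)
b = torch.randn(400, 1, 64, dtype=torch.float16)
out = torch.bmm(a.float(), b.float())
print(out.dtype, out.shape)  # torch.float32 torch.Size([400, 100, 64])
```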

tsdalton commented 4 years ago

I've added the versions above: I'm using CUDA 10.1.
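It may also be worth checking the GPU itself: the Tesla K80 is a Kepler-generation card, and fast FP16 GEMMs (Tensor Cores) require compute capability 7.0 or higher. A hedged sketch of that check; the 7.0 threshold is an assumption about what counts as "fast" here.

```python
import torch

# Tensor Cores (fast FP16 GEMMs) require compute capability >= 7.0
# (Volta and newer). The Tesla K80 reports (3, 7), so FP16 matmuls
# there go through a slower conversion path than FP32.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    fast_fp16 = (major, minor) >= (7, 0)
else:
    fast_fp16 = False  # CPU-only build: no CUDA FP16 path to check

print(fast_fp16)
```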