dmlc / gluon-nlp

NLP made easy
https://nlp.gluon.ai/
Apache License 2.0

Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTA, XLMR #1258

Closed · sxjscience closed this 4 years ago

sxjscience commented 4 years ago

@MoisesHer @dmlc/gluon-nlp-team

This PR adds two flags, layout and compute_layout, to the backbone models to improve computational speed and usability.

The technical insight into why the layout matters is as follows (also documented at https://github.com/dmlc/gluon-nlp/blob/cd48efdd51cf26b47791c0329e1092fb5fc658f7/src/gluonnlp/attention_cell.py#L540-L546):

When the layouts of memory and query are "TNC", the query has shape (L_query, B, N, C_Q) and the memory (key) has shape (L_mem, B, N, C_Q).

One step in the AttentionCell is to obtain the multi-head attention scores, which can be written as follows:

(L_query, B, N, C_Q) X (L_mem, B, N, C_Q) -> (B, N, L_query, L_mem)

This layout can be implemented very efficiently because the B and N axes are adjacent in memory and can be folded into a single batch dimension.

To get a clear picture of what is happening, consider the (i, j)-th slice of the output, i.e.,

out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T

This is just one GEMM call per (batch, head) pair. We can thus implement the whole score computation via a single batched GEMM call by specifying the strides correctly.
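For illustration, here is a minimal NumPy sketch (not part of this PR; shapes and names chosen for the example) that checks the single-einsum form of the score computation against the per-(i, j) GEMM described above:

```python
import numpy as np

# "TNC" layout with explicit heads: (seq length, batch, num heads, channels)
L_query, L_mem, B, N, C_Q = 4, 6, 2, 3, 8
query = np.random.randn(L_query, B, N, C_Q)
key = np.random.randn(L_mem, B, N, C_Q)

# (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) -> (B, N, L_query, L_mem)
scores = np.einsum('ibnc,jbnc->bnij', query, key)

# Equivalent per-(batch, head) GEMMs:
# out[i, j, :, :] = query[:, i, j, :] @ key[:, i, j, :].T
out = np.empty((B, N, L_query, L_mem))
for i in range(B):
    for j in range(N):
        out[i, j] = query[:, i, j, :] @ key[:, i, j, :].T

assert np.allclose(scores, out)
```

A batched-GEMM kernel realizes the same computation in one call by treating the adjacent (B, N) axes as the batch dimension and encoding the L and C axes through strides.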

Also, in fairseq, the inner computation of the TransformerEncoder uses the TN layout:

https://github.com/pytorch/fairseq/blob/108bb2560b1ec01524ba723bc7c69186875afa0a/fairseq/models/transformer.py#L399-L407
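Concretely, fairseq reaches that layout by transposing the embedded input before the encoder layers run. A minimal PyTorch sketch of just that step (tensor sizes are made up for the example):

```python
import torch

B, T, C = 2, 5, 16        # batch, time, channel
x = torch.randn(B, T, C)  # embedding output in batch-major layout

# B x T x C -> T x B x C: time-major layout used inside the encoder layers
x = x.transpose(0, 1)
assert x.shape == (T, B, C)
```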

After this PR, the following backbone models will support the layout flag: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTa, XLMR.
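As a hedged usage sketch: the flag names come from this PR's title, but the exact constructor path, defaults, and remaining arguments are assumptions here, so check the merged API before relying on it:

```python
from gluonnlp.models.bert import BertModel  # module touched in this PR

# layout: layout of the input/output tensors, e.g. 'NT' (batch-major)
#         or 'TN' (time-major).
# compute_layout: layout used for the internal computation; presumably
#         an 'auto'-style default follows the value of layout.
model = BertModel(layout='TN', compute_layout='auto')
```

With layout='TN', callers would feed token tensors of shape (seq_length, batch_size) instead of (batch_size, seq_length).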

codecov[bot] commented 4 years ago

Codecov Report

Merging #1258 into numpy will increase coverage by 1.35%. The diff coverage is 89.17%.


@@            Coverage Diff             @@
##            numpy    #1258      +/-   ##
==========================================
+ Coverage   82.56%   83.92%   +1.35%     
==========================================
  Files          38       41       +3     
  Lines        5534     6157     +623     
==========================================
+ Hits         4569     5167     +598     
- Misses        965      990      +25     
| Impacted Files | Coverage Δ |
|----------------|------------|
| setup.py | 0.00% <ø> (ø) |
| src/gluonnlp/models/transformer_xl.py | 82.52% <66.66%> (-0.20%) :arrow_down: |
| src/gluonnlp/models/xlmr.py | 88.23% <70.00%> (+1.35%) :arrow_up: |
| src/gluonnlp/models/mobilebert.py | 87.72% <81.67%> (+6.37%) :arrow_up: |
| src/gluonnlp/models/electra.py | 80.78% <84.68%> (+1.83%) :arrow_up: |
| src/gluonnlp/attention_cell.py | 79.52% <86.66%> (-0.39%) :arrow_down: |
| src/gluonnlp/models/roberta.py | 93.67% <87.50%> (+4.89%) :arrow_up: |
| src/gluonnlp/models/bert.py | 94.86% <92.96%> (+9.93%) :arrow_up: |
| src/gluonnlp/utils/testing.py | 94.11% <93.47%> (-3.03%) :arrow_down: |
| src/gluonnlp/models/albert.py | 95.47% <93.80%> (-1.22%) :arrow_down: |
| ... and 17 more | |
sxjscience commented 4 years ago

Now I'm waiting for https://github.com/dmlc/gluon-nlp/pull/1261. We added two flags: layout and compute_layout.

sxjscience commented 4 years ago

@dmlc/gluon-nlp-committers @MoisesHer @ZheyuYe Should be ready for review.