Open justheuristic opened 2 years ago
Regarding scaled_masked_softmax_cuda: the kernel from apex/csrc/megatron behaves the same as pytorch softmax on the forward pass, but its backward pass is in-place, saving two buffers (tmp and the return value).
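The memory saving comes from the standard softmax gradient identity, which can be evaluated by overwriting the incoming gradient buffer instead of allocating temporaries. A minimal numpy sketch of the math (function names here are illustrative, not apex's API):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def softmax_backward_inplace(grad_output, output):
    """grad_input = output * (grad_output - sum(grad_output * output)),
    computed by reusing the grad_output buffer (no tmp / return buffers)."""
    row_dot = np.sum(grad_output * output, axis=-1, keepdims=True)
    grad_output -= row_dot   # overwrite the incoming gradient in place
    grad_output *= output    # still the same buffer
    return grad_output
```

The only extra allocation is the per-row reduction, which is tiny compared to the full [seq_len x seq_len] attention matrix.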
It supports seq_len <= 2048 (likely easy to extend) and float16.
Regarding the next iteration of memory saving: implemented here https://github.com/krunt/mytorchcudamodules/blob/master/modules/mine_self_multihead_attn_func.py — a loop over the batch dimension, based on the Python version of multihead attention from apex.
We need to commit this code & tests to this repo. The logic should be enabled by an input argument flag.
Summary based on @krunt 's recent talk about FMHA design:
The shmemory way is significantly faster (~10x on fmha benchmark #8), but requires that all keys/values fit into shared memory. As a result, both FMHA and FasterTransformer are limited by head dimension 64 and sequence length 512.
In turn, the naive way supports arbitrary head size and sequence length, but is significantly slower because it needs to store/load intermediate values in global memory.
Based on these two solutions, we can produce a middle-of-the-road implementation that combines the flexibility of the naive strategy with most of the performance of the shmemory-based strategy.
First, for each query, compute a scalar log-sum-exp of its dot products with all keys, i.e.

```
result[i] = log(sum_over_j(exp(<query_i, key_j>)))
```

Log-sum-exps can be computed partially, in chunks of tile_size tokens.
Second, third, etc. tiles do the following:
```
# forall tile i = 0...num_queries/tile_size, j = 0...num_keys/tile_size
logaddexp_accumulators_i = load_logsumexp_outputs_from_previous_part()  # initially 1d[tile_size] of -inf
new_log_add_exps_ij = compute_dotproduct_logsumexp(query_tiles[i], key_tiles[j])
logaddexp_accumulators_i[:] = safe_logaddexp_pair(logaddexp_accumulators_i, new_log_add_exps_ij)
```
Here, compute_dotproduct_logsumexp stands for computing the dot products of queries with keys, followed by a reduce_logsumexp over all keys, in parallel for each query; safe_logaddexp_pair is an element-wise log-sum of two exponents, equivalent to torch.logaddexp.
i/o: load queries and keys, 2x [tile_size x head_size]; store logsumexps: small [tile_size] vectors.
flops: ~half of fusedMHA's forward pass, since there is no need to multiply probabilities by values at this stage.
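The tiled accumulation above can be checked against a monolithic log-sum-exp. A small numpy sketch (tile sizes and helper names are illustrative):

```python
import numpy as np

def logsumexp(x, axis=-1):
    # numerically stable log(sum(exp(x)))
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def chunked_query_logsumexp(queries, keys, tile_size):
    """For every query, accumulate log(sum_j exp(<q_i, k_j>)) over key tiles."""
    acc = np.full(len(queries), -np.inf)  # logaddexp accumulator, starts at -inf
    for start in range(0, len(keys), tile_size):
        key_tile = keys[start:start + tile_size]
        partial = logsumexp(queries @ key_tile.T, axis=-1)  # this tile's keys only
        acc = np.logaddexp(acc, partial)  # plays the role of safe_logaddexp_pair
    return acc
```

Only a [tile_size x head_size] key tile and a [tile_size] accumulator need to live on-chip at any time.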
Once we know the log-sum-exps, we no longer need to fit the entire set of keys and values into shared memory.
Instead, we can load one chunk at a time, compute partial attention outputs from that chunk, add them to the accumulator, then load the next chunk, etc.
```
# forall tile i = 0...num_queries/tile_size, j = 0...num_keys/tile_size
query_tiles[i], key_tiles[j], value_tiles[j] = load_into_shmemory()
attention_accumulators_i = load_partial_results_from_previous_part()  # initially 2d[tile_size, head_dim] of zeros
logsumexp_accumulator_i = load_from_stage_1_for_queries_i()
dot_product_ij = dot_product(query_tiles[i], key_tiles[j])
softmax_tile_ij = exp(dot_product_ij - logsumexp_accumulator_i)
attention_output_tile_ij = dot_product(softmax_tile_ij, value_tiles[j])
attention_accumulators_i[:] = attention_accumulators_i + attention_output_tile_ij
```
i/o: same as shmemory-based MHA, but with one extra tensor loaded.
flops: a bit less than shmemory-based MHA, since the softmax denominator is pre-computed.
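Assuming the stage-1 log-sum-exps are available, the stage-2 loop can be sketched in numpy (names are made up for illustration); because each softmax tile already has the correct denominator, tiles contribute to the output independently:

```python
import numpy as np

def attention_with_precomputed_lse(q, k, v, lse, tile_size):
    """Tile over keys/values; softmax denominators come from the precomputed
    per-query log-sum-exps, so each tile adds its partial output directly."""
    out = np.zeros((q.shape[0], v.shape[1]))  # attention accumulator, starts at zeros
    for start in range(0, k.shape[0], tile_size):
        kt = k[start:start + tile_size]
        vt = v[start:start + tile_size]
        scores = q @ kt.T                      # dot_product_ij
        probs = np.exp(scores - lse[:, None])  # softmax_tile_ij, denominator precomputed
        out += probs @ vt                      # accumulate attention_output_tile_ij
    return out
```

The result matches reference softmax attention exactly, with only one [tile_size x head_dim] key/value tile resident at a time.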
Use the same backward logic as in the shmemory implementation, but this time reuse the log-sum-exps saved from the forward pass and accumulate gradients tile by tile.
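One way to read this: with the saved log-sum-exps, the probability tiles can be recomputed in the backward pass instead of stored. A rough numpy sketch of accumulating the value gradient dV = P^T @ dO over key tiles (just the math, not the actual kernel; names are hypothetical):

```python
import numpy as np

def tiled_value_grad(q, k, grad_out, lse, tile_size):
    """Accumulate dV = P^T @ dO over key tiles, where each probability tile
    P = exp(Q K^T - lse) is recomputed from the saved log-sum-exps."""
    dv = np.zeros((k.shape[0], grad_out.shape[1]))
    for start in range(0, k.shape[0], tile_size):
        kt = k[start:start + tile_size]
        probs = np.exp(q @ kt.T - lse[:, None])           # recomputed softmax tile
        dv[start:start + tile_size] = probs.T @ grad_out  # grad w.r.t. this value tile
    return dv
```

The gradients w.r.t. queries and keys follow the same pattern: recompute the tile, apply the softmax backward identity, accumulate.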
fwd fmha for longer sequences is implemented in this fork: https://github.com/krunt/apex
k and v always stay in smem (no offload (!!!) to gmem while iterating over Q).
1) fwd is 2x-2.5x faster than lean (and memory-efficient too!) 2) head_dim > 64 is supported, though not optimally (fmha does not support it natively; the results suggest it is correct)
TODO:
1) support for initialization of cacc_max, cacc_sum, vacc
2) test gmem offload slowdown of cacc_max, cacc_sum
3) support different seqlen (via mask; currently not fixed in fmha, easy to do)
4) bwd
5) fwd accumulators in float (is it needed?)
bwd is ported:
fwd+bwd results:
This is a discussion of how to minimize memory usage of attention.
Current state: investigating apex's scaled_masked_softmax to check how it operates