learning-at-home / lean_transformer

Memory-efficient transformer. Work in progress.
MIT License

Memory-efficient attention #7

Open justheuristic opened 2 years ago

justheuristic commented 2 years ago

This is a discussion of how to minimize memory usage of attention.

Current state: investigating apex's scaled_masked_softmax to check how it operates

krunt commented 2 years ago

Regarding scaled_masked_softmax_cuda: scaled_masked_softmax_cuda from apex/csrc/megatron behaves the same as the PyTorch softmax on the forward pass, but its backward pass is in-place!!! (saving 2 buffers: tmp and return)

It supports seq_len <= 2048 (I think this is easy to extend) and float16.
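Why the backward can run in place: the softmax gradient needs only the saved output y and the incoming gradient g, via dx_j = y_j * (g_j - sum_k g_k * y_k). Once that one scalar dot product is known, each y_j can be overwritten by its own gradient, so no extra output buffer is needed. A minimal pure-Python sketch of that math for a single row (this is not apex's CUDA kernel; the function names are hypothetical):

```python
import math

def softmax(xs):
    """Numerically stable softmax of a 1-D list."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_backward_inplace(y, grad_out):
    """Overwrite y (the saved softmax output) with dL/dx and return it.

    Uses dL/dx_j = y_j * (g_j - sum_k g_k * y_k), where g = dL/dy.
    The dot product is computed first; after that, each y_j is only
    needed to produce its own gradient, so it can be replaced in place.
    """
    dot = sum(g * p for g, p in zip(grad_out, y))
    for j in range(len(y)):
        y[j] = y[j] * (grad_out[j] - dot)
    return y
```

Scaling and masking, which the apex kernel also handles, are left out of this sketch.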

Regarding the next iteration of memory saving: a loop over the batch dimension, based on the Python version of multihead attention from apex, is implemented here: https://github.com/krunt/mytorchcudamodules/blob/master/modules/mine_self_multihead_attn_func.py

We still need to commit this code and its tests to this repo. The logic should be enabled by an input argument flag.
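A hypothetical pure-Python illustration of the loop-over-batch idea (not krunt's module): processing one batch element at a time means only one [seq_len, seq_len] probability matrix is alive at once, instead of one per batch element. The `loop_over_batch` argument stands in for the input flag and is an assumption, not an existing API.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def probs_for(q, k):
    """Full [len(q), len(k)] softmax(q @ k^T) matrix for one sample."""
    rows = []
    for qi in q:
        scores = [dot(qi, kj) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        rows.append([x / z for x in w])
    return rows

def apply_probs(probs, v):
    """probs @ v for one sample."""
    return [[sum(p * vj[c] for p, vj in zip(row, v)) for c in range(len(v[0]))]
            for row in probs]

def attention(q, k, v, loop_over_batch=True):
    """q, k, v: [batch][seq_len][head_dim] nested lists (hypothetical layout)."""
    if loop_over_batch:
        # Each sample's probability matrix is consumed (and freed) before the next is built.
        return [apply_probs(probs_for(qb, kb), vb) for qb, kb, vb in zip(q, k, v)]
    # Baseline: materialise probabilities for the whole batch first.
    all_probs = [probs_for(qb, kb) for qb, kb in zip(q, k)]
    return [apply_probs(p, vb) for p, vb in zip(all_probs, v)]
```

Both paths perform the same arithmetic in the same order, so they return identical results; only the peak number of live probability matrices differs.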

justheuristic commented 2 years ago

Summary based on @krunt's recent talk about FMHA design:

The shmemory (shared-memory) way is significantly faster (~10x on the fmha benchmark, #8), but requires all keys/values to fit into shared memory. As a result, both FMHA and FasterTransformer are limited to head dimension 64 and sequence length 512.

In turn, the naive way supports arbitrary head size and sequence length, but is significantly slower because it needs to store/load intermediate values in global memory.
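A back-of-the-envelope sketch of why the two approaches scale differently, under the simplifying assumption that the shmemory way keeps all of one head's K and V resident at once, while the naive way writes one [seq_len, seq_len] score matrix per head to global memory (fp16, 2 bytes per element). Real kernels may tile differently; the helper names are hypothetical.

```python
def kv_bytes_per_head(seq_len, head_dim, bytes_per_elem=2):
    """Bytes to hold both K and V for one attention head."""
    return 2 * seq_len * head_dim * bytes_per_elem

def scores_bytes_per_head(seq_len, bytes_per_elem=2):
    """Bytes for the [seq_len, seq_len] intermediate matrix of one head."""
    return seq_len * seq_len * bytes_per_elem

for seq_len in (512, 2048):
    kv_kib = kv_bytes_per_head(seq_len, head_dim=64) / 1024
    scores_kib = scores_bytes_per_head(seq_len) / 1024
    print(f"seq_len={seq_len}: K+V {kv_kib:.0f} KiB, scores {scores_kib:.0f} KiB")
# seq_len=512: K+V 128 KiB, scores 512 KiB
# seq_len=2048: K+V 512 KiB, scores 8192 KiB
```

K+V grows linearly with seq_len while the naive intermediate grows quadratically, so the naive way pays in global-memory traffic while the shmemory way is bounded by the available shared memory.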

justheuristic commented 2 years ago

Based on these two solutions, we can produce a middle-of-the-road implementation that combines the flexibility of the naive strategy with most of the performance of the shmemory-based strategy.

Stage 1: compute log-sum-exps

For each query, compute a scalar log-sum-exp of its dot products with all keys, i.e. result[i] = log(sum_over_j(exp(<query_i, key_j>)))
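Tiling is valid because log-sum-exp composes over any partition of the keys into tiles $T_1, \dots, T_m$: the log-sum-exp over all keys equals the log-sum-exp of the per-tile results, and merging two partial results is exactly a numerically stable logaddexp (starting from $-\infty = \log 0$, which is the identity for logaddexp):

$$
\ell_i = \log \sum_{j} \exp\langle q_i, k_j \rangle
       = \log \sum_{t=1}^{m} \exp\big(\ell_i^{(t)}\big),
\qquad
\ell_i^{(t)} = \log \sum_{j \in T_t} \exp\langle q_i, k_j \rangle
$$

$$
\operatorname{logaddexp}(a, b) = \log\left(e^{a} + e^{b}\right) = \max(a, b) + \log\left(1 + e^{-|a - b|}\right)
$$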

Log-sum-exps can be computed partially, in chunks of tile_size tokens. Each tile merges its partial result into a running accumulator (for the first tile, the accumulator is still -inf):

# forall tile i = 0...num_queries/tile_size, j=0...num_keys/tile_size
logaddexp_accumulators_i = load_logsumexp_outputs_from_previous_part()  # initially 1d[tile_size] of -inf
new_log_add_exps_ij = compute_dotproduct_logsumexp(query_tiles[i], key_tiles[j])
logaddexp_accumulators_i[:] = safe_logaddexp_pair(logaddexp_accumulators_i, new_log_add_exps_ij)

Here, compute_dotproduct_logsumexp computes the dot products between the query tile and the key tile, followed by a logsumexp reduction over the keys in that tile, in parallel for each query; safe_logaddexp_pair is a numerically stable element-wise log(exp(a) + exp(b)), equivalent to torch.logaddexp.
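A runnable pure-Python reference for Stage 1, assuming row-major nested lists and a single tile_size for both queries and keys. The helper names mirror the pseudocode, but this is only an illustration of the math, not a CUDA kernel:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def logsumexp(xs):
    """Stable log(sum(exp(x))) over a list."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))

def safe_logaddexp_pair(a, b):
    """Stable log(exp(a) + exp(b)), like torch.logaddexp on scalars."""
    hi, lo = max(a, b), min(a, b)
    if hi == -math.inf:  # both are log(0)
        return -math.inf
    return hi + math.log1p(math.exp(lo - hi))

def tiled_logsumexp(queries, keys, tile_size):
    """Stage 1: per-query log-sum-exp of dot products, accumulated tile by tile."""
    result = []
    for i0 in range(0, len(queries), tile_size):
        q_tile = queries[i0:i0 + tile_size]
        acc = [-math.inf] * len(q_tile)  # accumulator starts at log(0)
        for j0 in range(0, len(keys), tile_size):
            k_tile = keys[j0:j0 + tile_size]
            for r, q in enumerate(q_tile):
                # compute_dotproduct_logsumexp for one query against one key tile
                partial = logsumexp([dot(q, k) for k in k_tile])
                acc[r] = safe_logaddexp_pair(acc[r], partial)
        result.extend(acc)
    return result
```

The result does not depend on tile_size, which is what lets a kernel pick the tile to fit its shared memory.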

I/O: load queries and keys (2x [tile_size x head_size]) and store log-sum-exps (small [tile_size] vectors). FLOPs: about half of fusedMHA's forward pass, since there is no need to multiply attention probabilities by values.
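Where the "~half" estimate comes from, counting each multiply-add as 2 FLOPs for $N_q$ queries, $N_k$ keys and head dimension $d$, and ignoring the $O(N_q N_k)$ exp/log work:

$$
\mathrm{FLOPs}(QK^\top) \approx 2 N_q N_k d, \qquad \mathrm{FLOPs}(PV) \approx 2 N_q N_k d
$$

$$
\frac{\mathrm{FLOPs}_{\text{stage 1}}}{\mathrm{FLOPs}_{\text{forward}}} \approx \frac{2 N_q N_k d}{2 N_q N_k d + 2 N_q N_k d} = \frac{1}{2}
$$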

Stage 2: forward (given logsumexp)

Once we know the log-sum-exps, we no longer need to load the entire set of keys and values into shared memory.

Instead, we can load one chunk of keys and values at a time, compute partial attention outputs from that chunk, add them to the accumulator, then load the next chunk, and so on.

# forall tile i = 0...num_queries/tile_size, j=0...num_keys/tile_size
query_tiles[i], key_tiles[j], value_tiles[j] = load_into_shmemory()
attention_accumulators_i = load_partial_results_from_previous_part()  # initially 2d[tile_size, head_dim] of zeros
logsumexp_accumulator_i = load_from_stage_1_for_queries_i()

dot_product_ij = dot_product(query_tiles[i], key_tiles[j])
softmax_tile_ij = exp(dot_product_ij - logsumexp_accumulator_i)
attention_output_tile_ij = dot_product(softmax_tile_ij, value_tiles[j])
attention_accumulators_i[:] = attention_accumulators_i + attention_output_tile_ij
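A pure-Python sketch of the Stage 2 loop, assuming the per-query log-sum-exps from Stage 1 are already available as a list `lse`. Because lse[i] is the full softmax normaliser, each key/value tile's contribution is final and can simply be added to the accumulator. Names and the nested-list layout are illustrative assumptions:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def stage2_forward(queries, keys, values, lse, tile_size):
    """Attention outputs computed tile by tile from precomputed log-sum-exps."""
    head_dim = len(values[0])
    out = []
    for i0 in range(0, len(queries), tile_size):
        q_tile = queries[i0:i0 + tile_size]
        lse_tile = lse[i0:i0 + tile_size]
        acc = [[0.0] * head_dim for _ in q_tile]  # [tile_size, head_dim] of zeros
        for j0 in range(0, len(keys), tile_size):
            k_tile = keys[j0:j0 + tile_size]
            v_tile = values[j0:j0 + tile_size]
            for r, q in enumerate(q_tile):
                # exact softmax probabilities of this query for this key tile
                probs = [math.exp(dot(q, k) - lse_tile[r]) for k in k_tile]
                for p, v in zip(probs, v_tile):
                    for d in range(head_dim):
                        acc[r][d] += p * v[d]
        out.extend(acc)
    return out
```

Unlike an online softmax, no rescaling of earlier partial results is needed here; the price is the extra Stage 1 pass.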

I/O: same as shmemory-based MHA, but with one extra tensor (the log-sum-exps) loaded. FLOPs: slightly fewer than shmemory-based MHA, since the softmax denominator is precomputed.

Stage 3: backward

Use the same backward logic as in the shmemory-based implementation, but this time reuse the log-sum-exps saved from the forward pass and accumulate gradients tile by tile.
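The thread does not spell out the backward, so the following is an assumption: a FlashAttention-style tiled backward that recomputes P_ij = exp(<q_i, k_j> - lse_i) from the saved log-sum-exps and uses D_i = <dO_i, O_i> in place of the softmax row sum, so that dS_ij = P_ij * (dP_ij - D_i). A pure-Python sketch with hypothetical names, without scaling or masking:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def stage3_backward(q, k, v, out, grad_out, lse, tile_size):
    """Tiled attention backward that recomputes probabilities from saved lse.

    Returns (dq, dk, dv). The full [n_q, n_k] probability matrix is never stored.
    """
    n_q, n_k, d = len(q), len(k), len(q[0])
    dv_dim = len(v[0])
    dq = [[0.0] * d for _ in range(n_q)]
    dk = [[0.0] * d for _ in range(n_k)]
    dv = [[0.0] * dv_dim for _ in range(n_k)]
    # D_i = sum_j P_ij * dP_ij, which simplifies to <dO_i, O_i>
    delta = [dot(go, o) for go, o in zip(grad_out, out)]
    for i0 in range(0, n_q, tile_size):
        for j0 in range(0, n_k, tile_size):
            for i in range(i0, min(i0 + tile_size, n_q)):
                for j in range(j0, min(j0 + tile_size, n_k)):
                    p = math.exp(dot(q[i], k[j]) - lse[i])  # recomputed probability
                    dp = dot(grad_out[i], v[j])             # dL/dP_ij
                    ds = p * (dp - delta[i])                # dL/dS_ij (softmax backward)
                    for c in range(dv_dim):
                        dv[j][c] += p * grad_out[i][c]
                    for c in range(d):
                        dq[i][c] += ds * k[j][c]
                        dk[j][c] += ds * q[i][c]
    return dq, dk, dv
```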

Notes:

krunt commented 2 years ago


Forward fmha for longer sequences is implemented in this fork: https://github.com/krunt/apex

K and V always stay in smem (no offload (!!!) to gmem during the iteration over Q).

1) fwd is 2x-2.5x faster than lean (and memory-efficient too!).
2) head_dim > 64 is supported, though not optimally (fmha does not support it at all). I hope it is correct; the results say so.

TODO:

1) support initialization of cacc_max, cacc_sum, vacc
2) test the slowdown from offloading cacc_max, cacc_sum to gmem
3) support different seqlens (via mask; currently not fixed in fmha, easy to do)
4) bwd
5) fwd accumulators in float (not sure whether this is needed)
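The accumulator names in the TODO (cacc_max, cacc_sum, vacc) read like the running max, running sum and unnormalised value accumulator of an online softmax. Assuming that interpretation (the kernel itself is not shown in this thread), here is a pure-Python sketch for one query: the state can be initialised, updated chunk by chunk, and finalised, so it could in principle be stored and resumed between K/V chunks. All names are hypothetical.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def init_state(head_dim):
    """(running max, running sum, unnormalised value accumulator) for one query."""
    return -math.inf, 0.0, [0.0] * head_dim

def update_state(state, q, k_chunk, v_chunk):
    """Fold one chunk of keys/values into the running state (online softmax)."""
    run_max, run_sum, acc = state
    scores = [dot(q, k) for k in k_chunk]
    new_max = max(run_max, max(scores))
    scale = math.exp(run_max - new_max)  # 0.0 on the first chunk, since run_max is -inf
    run_sum *= scale
    acc = [a * scale for a in acc]
    for s, v in zip(scores, v_chunk):
        w = math.exp(s - new_max)
        run_sum += w
        acc = [a + w * x for a, x in zip(acc, v)]
    return new_max, run_sum, acc

def finalize(state):
    """Normalise the accumulator to get softmax(q @ K^T) @ V."""
    _, run_sum, acc = state
    return [a / run_sum for a in acc]
```

The state is two scalars plus one head_dim vector per query, which is small enough to write out and reload between chunks; whether that matches the actual kernel is an assumption.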

krunt commented 2 years ago

bwd is ported:


krunt commented 2 years ago

fwd+bwd results:
