Memory-efficient attention #7
- regarding scaled_masked_softmax_cuda: it supports seq_len <= 2048 (I think easy to extend) and float16
- regarding the next iteration for memory saving: need to commit this code & tests to this repo
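For context, the operation such a fused kernel computes is just scale + mask + row-wise softmax over the attention scores. A plain PyTorch reference of that operation (illustrative only; the actual kernel API and its mask convention may differ):

```python
import torch

def scaled_masked_softmax_reference(scores, mask, scale):
    # scores: [batch, heads, seq_q, seq_k] attention logits (e.g. float16)
    # mask:   boolean tensor broadcastable to scores, True where attention is disallowed
    # scale:  scalar applied to the logits before the softmax
    scores = scores * scale
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1)
```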
Summary based on @krunt's recent talk about FMHA design:

The shmemory way is significantly faster (~10x on the fmha benchmark, see #8), but requires that all keys/values fit into shared memory. As a result, both FMHA and FasterTransformer are limited to head dimension 64 and sequence length 512. In turn, the naive way supports arbitrary head sizes and sequence lengths, but is significantly slower because it needs to store/load intermediate values in global memory.
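To make the trade-off concrete, a back-of-the-envelope estimate (illustrative sizes, not measured numbers): the naive path materializes the full [seq_len, seq_len] score matrix per (batch, head) in global memory, which is exactly what the shmemory path avoids writing out.

```python
def naive_scores_bytes(batch, heads, seq_len, bytes_per_element=2):
    # size of the intermediate [seq_len, seq_len] dot-product matrix, float16
    return batch * heads * seq_len * seq_len * bytes_per_element

print(naive_scores_bytes(8, 16, 512) / 2**20)   # 64 MiB
print(naive_scores_bytes(8, 16, 4096) / 2**20)  # 4096 MiB
```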
Based on these two solutions, we can produce a middle-of-the-road implementation that combines the flexibility of the naive strategy with most of the performance of the shmemory-based strategy.

**Stage 1: compute log-sum-exps**

For each query, compute a scalar log-sum-exp of its dot products with all keys, i.e. lse_i = log(sum_j exp(q_i · k_j)). Log-sum-exps can be partially computed in chunks of keys (a runnable sketch follows the pseudocode):

```
# forall tile i = 0...num_queries/tile_size, j=0...num_keys/tile_size
logaddexp_accumulators_i = load_logsumexp_outputs_from_previous_part() # initially 1d[tile_size] of -inf
new_log_add_exps_ij = compute_dotproduct_logsumexp(query_tiles[i], key_tiles[j])
logaddexp_accumulators_i[:] = safe_logaddexp_pair(logaddexp_accumulators_i, new_log_add_exps_ij)
```

Wherein i/o:
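A minimal single-head PyTorch sketch of this stage (function and variable names are illustrative, not from the repo):

```python
import torch

def chunked_logsumexp(queries, keys, tile_size=128):
    # queries: [num_queries, head_dim], keys: [num_keys, head_dim]
    lse = torch.full((queries.shape[0],), float("-inf"),
                     dtype=queries.dtype, device=queries.device)
    for j in range(0, keys.shape[0], tile_size):
        key_tile = keys[j:j + tile_size]
        # log-sum-exp of dot products against this key tile only
        partial = torch.logsumexp(queries @ key_tile.T, dim=-1)
        # safe_logaddexp_pair: fold the partial result into the running accumulator
        lse = torch.logaddexp(lse, partial)
    return lse  # matches torch.logsumexp(queries @ keys.T, dim=-1)
```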
**Stage 2: forward (given logsumexp)**

Once we know the log-sum-exps, we no longer need to keep the entire set of keys/values in shared memory. Instead, we can load one chunk at a time, compute partial attention outputs from that chunk, add them to the accumulator, then load the next chunk, etc. (a runnable sketch follows the pseudocode):

```
# forall tile i = 0...num_queries/tile_size, j=0...num_keys/tile_size
query_tiles[i], key_tiles[j], value_tiles[j] = load_into_shmemory()
attention_accumulators_i = load_partial_results_from_previous_part() # initially 2d[num_queries, head_dim] of zeros
logsumexp_accumulator_i = load_from_stage_1_for_queries_i()
dot_product_ij = dot_product(query_tiles[i], key_tiles[j])
softmax_tile_ij = exp(dot_product_ij - logsumexp_accumulator_i)
attention_output_tile_ij = dot_product(softmax_tile_ij, value_tiles[j])
attention_accumulators_i[:] = attention_accumulators_i + attention_output_tile_ij
```

I/O is the same as for shmemory-based MHA, but with one extra tensor loaded (the log-sum-exps from stage 1).
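The corresponding single-head PyTorch sketch of the forward pass, assuming the log-sum-exps from stage 1 (again, illustrative names only):

```python
import torch

def chunked_attention_forward(queries, keys, values, lse, tile_size=128):
    # queries: [num_queries, head_dim]; keys/values: [num_keys, head_dim]
    # lse: [num_queries] log-sum-exps precomputed in stage 1
    out = torch.zeros_like(queries)
    for j in range(0, keys.shape[0], tile_size):
        key_tile = keys[j:j + tile_size]
        value_tile = values[j:j + tile_size]
        # exp(scores - lse) is already correctly normalized, so partial
        # outputs can simply be summed without any rescaling
        softmax_tile = torch.exp(queries @ key_tile.T - lse[:, None])
        out += softmax_tile @ value_tile
    return out
```

Combined with the stage-1 helper, `chunked_attention_forward(q, k, v, chunked_logsumexp(q, k))` should match the reference `torch.softmax(q @ k.T, dim=-1) @ v` up to floating-point error.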
**Stage 3: backward**

Use the same backward logic as in the shmemory implementation, but this time reuse the log-sum-exps saved from the forward pass and accumulate gradients by tiles (a sketch of the gradient math is below).

Notes:
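For reference, the per-row gradient identities that such a tiled backward accumulates, reusing the saved log-sum-exps instead of a stored softmax (an untiled single-head PyTorch sketch with illustrative names):

```python
import torch

def attention_backward_reference(queries, keys, values, lse, out, d_out):
    # recompute probabilities from the saved log-sum-exps instead of storing the softmax
    probs = torch.exp(queries @ keys.T - lse[:, None])   # [num_queries, num_keys]
    d_values = probs.T @ d_out                            # accumulated per key tile in a kernel
    d_probs = d_out @ values.T
    # softmax Jacobian correction: sum_j P_ij * dP_ij equals rowsum(d_out * out)
    delta = (d_out * out).sum(dim=-1, keepdim=True)       # [num_queries, 1]
    d_scores = probs * (d_probs - delta)
    d_queries = d_scores @ keys                           # accumulated per key tile
    d_keys = d_scores.T @ queries
    return d_queries, d_keys, d_values
```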
Forward FMHA for longer sequences is implemented on this fork: https://github.com/krunt/apex. K/V are always kept in shared memory (no offload (!!!) to global memory during iteration over Q).
TODO:
This is a discussion of how to minimize memory usage of attention.
Current state: investigating apex's scaled_masked_softmax to check how it operates