pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

Why does flex-attention run faster than FA2 in attention-gym, but much slower than FA2 in a real environment? #27

Closed: foreverpiano closed this 3 months ago

foreverpiano commented 3 months ago

[Image]

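```python
import math
import time
from functools import lru_cache

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# query, key, value: [bs, heads, seq_len, 96]. attention_mask, seq_len, d_k,
# layer_index and prefix_lm_causal_mask come from the surrounding model code.

if use_sdpa_baseline:  # hypothetical flag; the original condition was omitted
    # Memory-efficient (xformers-derived) SDPA backend, head dim 96.
    torch.cuda.synchronize()
    before_time = time.perf_counter()
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print("xformer: ", end_time - before_time)
    # [bs, head, seq_len, dim] * [bs, head, seq_len, dim]

else:
    print("flex:", layer_index)

    def expand_to_128(tensor):
        # Zero-pad the head dimension (96 -> 128); flex then runs on dim 128.
        padding_size = 128 - tensor.size(-1)
        return F.pad(tensor, (0, padding_size))

    query_expanded = expand_to_128(query)
    key_expanded = expand_to_128(key)
    value_expanded = expand_to_128(value)

    # NOTE: defined inside this branch, so each time the branch runs a fresh,
    # empty lru_cache is created and the cache never hits. Define it at
    # module level to actually reuse block masks.
    @lru_cache
    def create_block_mask_cached(mask_mod, B, H, M, N, device="cuda"):
        return create_block_mask(mask_mod, B, H, M, N, device=device)

    print(query_expanded.shape, key_expanded.shape, value_expanded.shape)

    torch.cuda.synchronize()
    before_time = time.perf_counter()
    # Unlike the SDPA branch, this timed region includes block-mask creation.
    block_mask = create_block_mask_cached(prefix_lm_causal_mask, 1, 1, seq_len, seq_len)

    # Pass the unpadded scale so the zero padding does not change the result.
    # If flex_attention is not wrapped in torch.compile, it falls back to an
    # unfused implementation instead of a generated fused kernel.
    hidden_states = flex_attention(
        query_expanded, key_expanded, value_expanded,
        block_mask=block_mask, scale=1.0 / math.sqrt(d_k),
    )

    del block_mask
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print("flex_attn: ", end_time - before_time)

    def shrink_to_96(tensor):
        # Drop the padded columns to restore head dim 96.
        return tensor[..., :96]

    hidden_states = shrink_to_96(hidden_states)
```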
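If the snippet above runs inside a per-layer forward pass, the `@lru_cache` helper is re-created every time the flex branch executes, so `create_block_mask` is called on every forward instead of being reused. A minimal plain-Python sketch of that effect follows; `expensive_mask`, `forward_local_cache`, and `module_cached_mask` are hypothetical stand-ins for `create_block_mask` and the model's forward, not attention-gym code.

```python
from functools import lru_cache

BUILD_COUNT = {"local": 0, "module": 0}


def expensive_mask(n):
    # Hypothetical stand-in for create_block_mask: a causal 0/1 mask as tuples.
    return tuple(tuple(int(j <= i) for j in range(n)) for i in range(n))


def forward_local_cache(n):
    # Anti-pattern from the snippet: the cached helper is defined per call,
    # so it starts with an empty cache every time and never gets a hit.
    @lru_cache
    def cached_mask(m):
        BUILD_COUNT["local"] += 1
        return expensive_mask(m)

    return cached_mask(n)


@lru_cache
def module_cached_mask(n):
    # Defined once at module level: repeated calls with the same n hit the cache.
    BUILD_COUNT["module"] += 1
    return expensive_mask(n)


for _ in range(5):
    forward_local_cache(8)
    module_cached_mask(8)

print(BUILD_COUNT)  # {'local': 5, 'module': 1}
```

The same fix applies to the real helper. Because its arguments (a mask function, integers, and a device string) are hashable, moving the decorated function to module level lets later layers and steps reuse the mask.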

This is part of the attention code from my real use case. The results are xformer: 0.10 s (matching the table) and flex: 0.50 s (10x slower than the table). @drisspg
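Both numbers come from a single timed call, and the flex timing also covers block-mask creation plus any first-call compilation if `flex_attention` is compiled. Below is a minimal sketch of a fairer harness: warm-up calls first, then the median of several synchronized runs. `benchmark` is a hypothetical helper, not part of attention-gym or PyTorch. For GPU work, the assumption is that you pass `torch.cuda.synchronize` as `sync`.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], warmup: int = 3, iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock seconds per fn() call, measured after warm-up calls."""
    sync = sync or (lambda: None)
    for _ in range(warmup):
        fn()  # absorbs one-time costs such as compilation or cache filling
    times = []
    for _ in range(iters):
        sync()  # drain pending GPU work before starting the clock
        start = time.perf_counter()
        fn()
        sync()  # wait for fn's GPU work to finish before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Toy stand-in: only the first call pays a simulated 0.2 s one-time cost.
calls = {"n": 0}


def slow_first_call():
    calls["n"] += 1
    if calls["n"] == 1:
        time.sleep(0.2)


median_s = benchmark(slow_first_call)
print(f"median: {median_s:.4f} s over {calls['n']} calls")
```

With PyTorch, one would build the block mask once outside the timed function and call something like `benchmark(lambda: flex_attention(q, k, v, block_mask=block_mask), sync=torch.cuda.synchronize)`, where `q`, `k`, `v`, and `block_mask` are placeholders for already-prepared tensors.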

NonvolatileMemory commented 2 months ago

So, do you know why?

foreverpiano commented 2 months ago

@NonvolatileMemory It was a padding issue: FA2 was tested with dim=96, while flex was tested with dim=128, so it's an unfair comparison.
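To make the padding point concrete, here is a framework-free sketch in which plain Python lists stand in for the PyTorch tensors. `attention`, `pad`, and `max_abs_diff` are hypothetical helpers. The sketch checks two things. First, zero-padding the head dimension from 96 to 128 and slicing back gives the same output as long as the scale stays 1/sqrt(96); it changes if the scale is derived from the padded width. Second, the padded run does 128/96 of the QK^T and PV matmul work.

```python
import math


def attention(q, k, v, scale):
    """Naive single-head attention, softmax(q k^T * scale) v, on lists of rows."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


def pad(rows, width):
    """Zero-pad each row's head dimension, like expand_to_128."""
    return [row + [0.0] * (width - len(row)) for row in rows]


def max_abs_diff(x, y):
    return max(abs(a - b) for rx, ry in zip(x, y) for a, b in zip(rx, ry))


d_k, d_pad, n = 96, 128, 4
q = [[math.sin(i * d_k + d) for d in range(d_k)] for i in range(n)]
k = [[math.cos(i * d_k + d) for d in range(d_k)] for i in range(n)]
v = [[0.01 * (i + 1) * d for d in range(d_k)] for i in range(n)]
qp, kp, vp = pad(q, d_pad), pad(k, d_pad), pad(v, d_pad)

ref = attention(q, k, v, scale=1 / math.sqrt(d_k))
# Same scale as the unpadded run, then slice back (like shrink_to_96).
sliced = [row[:d_k] for row in attention(qp, kp, vp, scale=1 / math.sqrt(d_k))]
# Scale taken from the padded width instead: the result changes.
wrong = [row[:d_k] for row in attention(qp, kp, vp, scale=1 / math.sqrt(d_pad))]

max_err = max_abs_diff(ref, sliced)
wrong_err = max_abs_diff(ref, wrong)
work_ratio = d_pad / d_k  # QK^T and PV FLOPs grow linearly with head dim
print(max_err < 1e-9, wrong_err > 1e-6, f"{work_ratio:.2f}")  # True True 1.33
```

Padding therefore keeps the flex result correct, since the snippet passes `scale=1./math.sqrt(d_k)` explicitly, but it makes the flex run do strictly more arithmetic than the dim-96 SDPA run. That is why the two timings are not a like-for-like comparison.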