Closed — foreverpiano closed this issue 3 months ago.
# Benchmark fragment: time one attention call either through SDPA restricted to
# the memory-efficient backend (printed as "xformer") or through flex_attention
# (printed as "flex_attn"). Each timed region is bracketed by
# torch.cuda.synchronize() so the perf_counter delta includes the queued GPU work.
# NOTE(review): this snippet was pasted from inside a larger function; query, key,
# value, attention_mask, layer_index, seq_len, d_k, prefix_lm_causal_mask,
# create_block_mask, flex_attention, F, math, time and lru_cache all come from
# that unseen scope.
# NOTE(review): the condition of this `if` was lost in the paste (`if :` is a
# syntax error) -- restore it before running.
if :
    torch.cuda.synchronize()
    before_time = time.perf_counter()
    # Math and flash backends are disabled, so SDPA can only use the
    # memory-efficient kernel; the "xformer" label presumably refers to it.
    # NOTE(review): newer PyTorch releases deprecate torch.backends.cuda.sdp_kernel
    # in favour of torch.nn.attention.sdpa_kernel -- confirm for your version.
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    # This timer covers only the attention call itself.
    print("xformer: ", end_time - before_time)
    # [bs, head, seq_len, dim] * [bs, head, seq_len, dim]
else:
    print("flex:", layer_index)

    # Zero-pad the last dimension of a tensor up to 128.
    # NOTE(review): flex therefore runs with head dim 128 while the SDPA branch
    # above uses the unpadded tensors (96 per the thread), so the two timings are
    # not like-for-like. If the last dim were already > 128, padding_size would
    # be negative; that case is not handled here.
    def expand_to_128(tensor):
        padding_size = 128 - tensor.size(-1)
        return torch.nn.functional.pad(tensor, (0, padding_size))

    query_expanded = expand_to_128(query)
    key_expanded = expand_to_128(key)
    value_expanded = expand_to_128(value)

    # Memoized wrapper around create_block_mask.
    # NOTE(review): this function is (re)defined every time this branch runs, so
    # each execution starts with a fresh, empty lru_cache and never gets a hit
    # across calls; define it once at module level if caching is intended.
    @lru_cache
    def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"):
        block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
        return block_mask

    # Identity score_mod; not used anywhere in the visible code.
    def noop(score, b, h, q_idx, kv_idx):
        return score

    print(query_expanded.shape, key_expanded.shape, value_expanded.shape)
    # NOTE(review): unlike the SDPA timer above, this timed region includes
    # building the block mask as well as the attention call.
    # NOTE(review): it is not visible here whether `flex_attention` was wrapped in
    # torch.compile; FlexAttention is meant to be compiled for performance, and a
    # single un-warmed call may also include one-time overhead -- confirm.
    torch.cuda.synchronize()
    before_time = time.perf_counter()
    block_mask = create_block_mask_cached(prefix_lm_causal_mask, 1, 1, seq_len, seq_len)
    # Scale is passed explicitly as 1/sqrt(d_k), presumably so the zero-padded
    # head dim does not change the softmax scale.
    hidden_states = flex_attention(query_expanded, key_expanded, value_expanded, block_mask=block_mask, scale=1./math.sqrt(d_k))
    del block_mask
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print("flex_attn: ", end_time - before_time)

    # Drop the padded columns from the output.
    # NOTE(review): hard-codes 96, whereas expand_to_128 pads from whatever the
    # input dim was; the two only agree when that dim is 96.
    def shrink_to_96(tensor):
        return tensor[..., :96]

    hidden_states = shrink_to_96(hidden_states)
This is part of the attention code from a real use case. The results are: xformer: 0.10 s (matches the table); flex: 0.50 s (10x slower than the table). @drisspg
So, do you know why?
@NonvolatileMemory It's a padding issue: FA2 is tested with dim=96, whereas Flex is tested with dim=128 (the inputs are zero-padded from 96 to 128), so it's an unfair comparison.
This is part of the attention code from a real use case. The results are: xformer: 0.10 s (matches the table); flex: 0.50 s (10x slower than the table). @drisspg