pytorch / ao

PyTorch native quantization and sparsity for training and inference

Replace flash_4 with FlexAttention #639

Open cpuhrsch opened 1 month ago

cpuhrsch commented 1 month ago

https://github.com/pytorch-labs/segment-anything-fast/ uses custom Triton code to implement a variant of SDPA that supports the kind of additive attention bias required by the image_encoder.

In a nutshell, the code it implements with this custom Triton kernel is:

    # rel_h_: (B, num_heads, q_h * q_w, k_h, 1); rel_w_: (B, num_heads, q_h * q_w, 1, k_w)
    rel_h_ = rel_h.reshape(B, num_heads, q_h * q_w, k_h, 1)
    rel_w_ = rel_w.reshape(B, num_heads, q_h * q_w, 1, k_w)
    # Broadcast-add, then flatten to (B, num_heads, q_h * q_w, k_h * k_w) for SDPA's attn_mask
    attn_bias = (rel_h_ + rel_w_).view(q_.size(0), q_.size(1),
                                       rel_h_.size(2), rel_h_.size(3) * rel_w_.size(4))
    return torch.nn.functional.scaled_dot_product_attention(q_, k_, v_, attn_mask=attn_bias)

With the release of FlexAttention in PyTorch 2.5 (code examples), it should now be possible to express this without the need for custom Triton code.

Not only will FlexAttention be able to support a fused implementation for more input shapes, it is also likely to produce more optimal code with better hyperparameters. This kind of fused attention yielded an end-to-end improvement of about 1.15x on top of a fused-SDPA and torch.compile'd (with CUDA graphs) baseline.
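
For concreteness, here is a minimal sketch of one way this could look with FlexAttention's score_mod, indexing the decomposed rel_h_ / rel_w_ tensors from the snippet above directly so the full bias never has to be materialized (the function name and signature are illustrative, not landed code):

    from torch.nn.attention.flex_attention import flex_attention

    def attention_rel_h_rel_w_flex(q_, k_, v_, rel_h_, rel_w_, k_w):
        # Drop the broadcast singleton dims -> (B, num_heads, q_len, k_h) and (B, num_heads, q_len, k_w)
        rel_h_ = rel_h_.squeeze(-1)
        rel_w_ = rel_w_.squeeze(-2)

        def score_mod(score, b, h, q_idx, kv_idx):
            # kv_idx runs over the flattened k_h * k_w key positions
            return score + rel_h_[b, h, q_idx, kv_idx // k_w] + rel_w_[b, h, q_idx, kv_idx % k_w]

        return flex_attention(q_, k_, v_, score_mod=score_mod)

In practice the call would be wrapped in torch.compile so that FlexAttention generates a fused kernel.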

The task:

Copy over the relevant files from segment-anything-fast into torchao's model folder and follow the README to rerun the benchmarks if needed.

Write a FlexAttention version of flash_4 and measure the difference in performance. If it helps, we can immediately land it in torchao, but at a minimum it could influence FlexAttention development.
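
To measure the difference in isolation, a rough micro-benchmark along the following lines could be used (shapes, dtypes, and names here are illustrative assumptions, not the torchao SAM benchmark):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention.flex_attention import flex_attention
    from torch.utils.benchmark import Timer

    B, H, Q, KV, D = 4, 16, 1024, 1024, 128  # illustrative shapes only
    q, k, v = (torch.randn(B, H, Q, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
    bias = torch.randn(B, H, Q, KV, device="cuda", dtype=torch.bfloat16)

    def sdpa_with_bias():
        return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

    compiled_flex = torch.compile(flex_attention)

    def bias_mod(score, b, h, q_idx, kv_idx):
        # Same additive bias, expressed as a score_mod over the captured tensor
        return score + bias[b, h, q_idx, kv_idx]

    def flex_with_bias():
        return compiled_flex(q, k, v, score_mod=bias_mod)

    for name, fn in [("sdpa+bias", sdpa_with_bias), ("flex_attention", flex_with_bias)]:
        print(name, Timer(stmt="fn()", globals={"fn": fn}).blocked_autorange())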

tobiasvanderwerff commented 5 days ago

@cpuhrsch

I would like to give this a shot. Could you help me clarify something?

Is the goal to make a fork of segment-anything-fast that uses Flex Attention, and test that in ao? The alternative would be to manually copy over all the files from segment-anything-fast to ao/torchao/_models/sam/, but that seems overkill since the only change is in the SDPA call.

What I could do is make a fork of segment-anything-fast that uses Flex Attention and install that as an alternative to pip3 install git+https://github.com/pytorch-labs/segment-anything-fast.git when benchmarking SAM.

Let me know if this makes any sense, or if you meant something else.

cpuhrsch commented 4 days ago

@tobiasvanderwerff - Yes, we could also get started with an experimental PR against https://github.com/pytorch-labs/segment-anything-fast. Eventually it could be convenient to be able to vendor the SAM-fast changes and make them more easily accessible via torchao packaging and distribution. What do you think about this?

tobiasvanderwerff commented 4 days ago

@cpuhrsch that sounds like a plan. Let me try to get started on this in the next few days.

I already tried to run the SAM benchmark today to get started but realized that my current GPU (NVIDIA T4) does not support Flash Attention (since it requires compute capability >=sm_80, e.g. an A100). However, I intend to get access to a cloud A100 GPU instance in the next few days.
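
For reference, checking whether the current GPU meets the sm_80 requirement is a one-liner using the standard PyTorch API:

    import torch
    print(torch.cuda.get_device_capability())  # (7, 5) on a T4, (8, 0) on an A100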

If getting access to a better GPU doesn't work out, I don't think I'll be able to work on this, and I'll let you know in that case.

tobiasvanderwerff commented 1 day ago

@cpuhrsch as discussed, I've created a fork of the segment-anything-fast repo that uses Flex Attention instead of the custom Triton kernel. I've also added a test to check for correctness. You can see the changes here.
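
A correctness check of this kind might look roughly as follows (a sketch, not the fork's actual test; it compares an uncompiled FlexAttention call with a bias score_mod against the SDPA reference, with loose tolerances for bfloat16):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention.flex_attention import flex_attention

    q, k, v = (torch.randn(2, 16, 1024, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3))
    bias = torch.randn(2, 16, 1024, 1024, device="cuda", dtype=torch.bfloat16)

    ref = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    out = flex_attention(q, k, v, score_mod=lambda s, b, h, qi, ki: s + bias[b, h, qi, ki])
    torch.testing.assert_close(out, ref, atol=3e-2, rtol=3e-2)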

I'm posting benchmark results from ao/torchao/_models/sam/benchmark.sh below. First results are not terribly encouraging: the Flex Attention implementation leads to a ~25% reduction in img/s. I might do some more digging to see why this is happening. If you have any suggestions, I'd love to hear them.

As a side note, Flex Attention only accepts embedding sizes that are powers of two, so I had to add padding to make it work. It's possible that the padding accounts for the performance regression, although the Triton kernel appears to do the same thing.
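
For context, the padding workaround looks roughly like this (a sketch assuming the restriction is on the per-head embedding dimension; zero-padding q and k leaves the attention scores unchanged, and the padded output channels are sliced off again):

    import math
    import torch.nn.functional as F
    from torch.nn.attention.flex_attention import flex_attention

    def padded_flex_attention(q, k, v, score_mod=None):
        head_dim = q.size(-1)
        padded_dim = 2 ** math.ceil(math.log2(head_dim))  # e.g. 80 -> 128 for SAM ViT-H
        if padded_dim != head_dim:
            q, k, v = (F.pad(t, (0, padded_dim - head_dim)) for t in (q, k, v))
        # Keep the softmax scale of the original head dim, not the padded one.
        out = flex_attention(q, k, v, score_mod=score_mod, scale=head_dim ** -0.5)
        return out[..., :head_dim]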

Torch version: 2.6.0.dev20240918
GPU: A100 80GB

Baseline results (using the Triton kernel):

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_h | 32 | 15172 | 18 | 22.533401716616083 | 44.37856354651513 | 0.5812715827356921 | max-autotune | torch.bfloat16 | None | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 15154 | 18 | 25.16516896830006 | 39.73746416166231 | 0.5818834536577897 | max-autotune | torch.bfloat16 | int8_dynamic_quant | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 15632 | 19 | 24.824717871078573 | 40.282431614863405 | 0.5675837487618974 | max-autotune | torch.bfloat16 | sparse_mlp_only | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 13429 | 16 | 24.589577947798148 | 40.66763578142439 | 0.5306639662569573 | max-autotune | torch.bfloat16 | sparse | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 14869 | 18 | 26.597207143088742 | 37.597932543073384 | 0.5669944616184625 | max-autotune | torch.bfloat16 | int8_dynamic_quant_sparse | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 17068 | 21 | 23.96093702681232 | 41.73459489004953 | 0.5485481164943489 | max-autotune | torch.float16 | int4_weight_only_sparse | False | True | True | 32 | 154 | 4928 | None | None |
Flex Attention results (I omitted the last two rows because running the benchmark was taking a long time):

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_h | 32 | 19531 | 24 | 16.35339887491553 | 61.14936764209301 | 0.5812806843206303 | max-autotune | torch.bfloat16 | None | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 19512 | 24 | 17.72072649749095 | 56.43109497466644 | 0.5815980109018701 | max-autotune | torch.bfloat16 | int8_dynamic_quant | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 20960 | 25 | 16.6174344353318 | 60.177761127422386 | 0.5672995875671748 | max-autotune | torch.bfloat16 | sparse_mlp_only | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 18997 | 23 | 14.915692058093141 | 67.04348655799767 | 0.5306602491658978 | max-autotune | torch.bfloat16 | sparse | False | True | True | 24 | 154 | 4928 | None | None |
cpuhrsch commented 10 hours ago

Hm, very interesting. Thanks for doing this work. Do you mind attaching GPU traces for, say, the first setup, both with and without FlexAttention?

You can gather traces using https://github.com/pytorch-labs/segment-anything-fast/tree/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/experiments#kernel-traces. Just ensure the path ends in .json.gz.

cpuhrsch commented 10 hours ago

Using the GPU traces, it is also possible to annotate the section that was changed (using https://pytorch.org/docs/main/generated/torch.autograd.profiler.record_function.html#record-function and https://pytorch.org/docs/main/generated/torch.cuda.synchronize.html#torch-cuda-synchronize) and look at the difference in GPU kernel runtime only. This way we can double-check that the slowdown is due precisely to this change.
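
For example (a minimal sketch; the random tensors and the plain SDPA call stand in for the changed attention section):

    import torch
    from torch.autograd.profiler import record_function
    from torch.profiler import ProfilerActivity, profile

    q, k, v = (torch.randn(2, 16, 1024, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3))

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        torch.cuda.synchronize()
        with record_function("rel_pos_attention"):  # shows up as a labeled region in the trace
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
    prof.export_chrome_trace("attention_annotated.json.gz")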

I'd create two versions of these traces, one with annotation and sync and one without. So that means 4 traces in total:

a) Baseline without annotation
b) Baseline with annotation
c) Changed without annotation
d) Changed with annotation