cpuhrsch opened this issue 1 month ago
@cpuhrsch I would like to give this a shot. Could you help me clarify something?
Is the goal to make a fork of segment-anything-fast that uses Flex Attention and test that in `ao`? The alternative would be to manually copy all the files from segment-anything-fast to `ao/torchao/_models/sam/`, but that seems like overkill since the only change is in the SDPA call.
What I could do is make a fork of segment-anything-fast that uses Flex Attention and install it in place of `pip3 install git+https://github.com/pytorch-labs/segment-anything-fast.git` when benchmarking SAM.
Let me know if this makes any sense, or if you meant something else.
@tobiasvanderwerff Yes, we could also start with an experimental PR against https://github.com/pytorch-labs/segment-anything-fast. Eventually, it could be convenient to vendor the changes in SAM-fast and make them more easily accessible via torchao packaging and distribution. What do you think?
@cpuhrsch that sounds like a plan. Let me try to get started on this in the next few days.
I already tried to run the SAM benchmark today to get started, but realized that my current GPU (an NVIDIA T4) does not support Flash Attention, which requires compute capability sm_80 or higher (e.g. an A100). However, I intend to get access to a cloud A100 GPU instance in the next few days.
If getting access to a better GPU doesn't work out, I don't think I'll be able to work on this, and I'll let you know in that case.
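Whether a given machine meets the sm_80 requirement can be checked by querying the device's compute capability. The helper below is a hypothetical sketch (the name `supports_sm80` is not from either repository); a T4 reports (7, 5) and an A100 reports (8, 0).

```python
import torch


def supports_sm80() -> bool:
    """Return True if the current CUDA device has compute capability >= 8.0 (sm_80)."""
    if not torch.cuda.is_available():
        return False
    # get_device_capability() returns a (major, minor) tuple, e.g. (7, 5) on a T4.
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 0)


print(f"sm_80 or newer: {supports_sm80()}")
```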
@cpuhrsch as discussed, I've created a fork of the segment-anything-fast repo that uses Flex Attention instead of the custom Triton kernel. I've also added a test to check for correctness. You can see the changes here.
I'm posting benchmark results from `ao/torchao/_models/sam/benchmark.sh` below. The first results are not very encouraging: the Flex Attention implementation reduces img/s by roughly 27% to 39%, depending on the `compress` setting. I might do some more digging to see why this is happening. If you have any suggestions, I'd love to hear them.
As a side note, Flex Attention only accepts per-head embedding sizes that are powers of two, so I had to add padding to make it work. It's possible that the padding causes the slowdown, although the Triton kernel seems to pad as well.
Torch version: 2.6.0.dev20240918
GPU: A100 80GB
Baseline results (using the Triton kernel):

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_h | 32 | 15172 | 18 | 22.533401716616083 | 44.37856354651513 | 0.5812715827356921 | max-autotune | torch.bfloat16 | None | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 15154 | 18 | 25.16516896830006 | 39.73746416166231 | 0.5818834536577897 | max-autotune | torch.bfloat16 | int8_dynamic_quant | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 15632 | 19 | 24.824717871078573 | 40.282431614863405 | 0.5675837487618974 | max-autotune | torch.bfloat16 | sparse_mlp_only | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 13429 | 16 | 24.589577947798148 | 40.66763578142439 | 0.5306639662569573 | max-autotune | torch.bfloat16 | sparse | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 14869 | 18 | 26.597207143088742 | 37.597932543073384 | 0.5669944616184625 | max-autotune | torch.bfloat16 | int8_dynamic_quant_sparse | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 17068 | 21 | 23.96093702681232 | 41.73459489004953 | 0.5485481164943489 | max-autotune | torch.float16 | int4_weight_only_sparse | False | True | True | 32 | 154 | 4928 | None | None |
Flex Attention results (I omitted the last two rows because running the benchmark was taking a long time):

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_h | 32 | 19531 | 24 | 16.35339887491553 | 61.14936764209301 | 0.5812806843206303 | max-autotune | torch.bfloat16 | None | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 19512 | 24 | 17.72072649749095 | 56.43109497466644 | 0.5815980109018701 | max-autotune | torch.bfloat16 | int8_dynamic_quant | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 20960 | 25 | 16.6174344353318 | 60.177761127422386 | 0.5672995875671748 | max-autotune | torch.bfloat16 | sparse_mlp_only | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 18997 | 23 | 14.915692058093141 | 67.04348655799767 | 0.5306602491658978 | max-autotune | torch.bfloat16 | sparse | False | True | True | 24 | 154 | 4928 | None | None |
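The relative change in throughput can be read directly off the img_s(avg) column. The short script below only does that arithmetic for the four `compress` settings present in both tables, with values copied verbatim from them.

```python
# img_s(avg) from the two tables, keyed by the `compress` column.
baseline_img_s = {
    "None": 22.533401716616083,
    "int8_dynamic_quant": 25.16516896830006,
    "sparse_mlp_only": 24.824717871078573,
    "sparse": 24.589577947798148,
}
flex_img_s = {
    "None": 16.35339887491553,
    "int8_dynamic_quant": 17.72072649749095,
    "sparse_mlp_only": 16.6174344353318,
    "sparse": 14.915692058093141,
}

# Percentage drop in images per second when switching to Flex Attention.
reductions = {
    compress: 100.0 * (1.0 - flex_img_s[compress] / base)
    for compress, base in baseline_img_s.items()
}
for compress, pct in reductions.items():
    print(f"{compress:>20}: {pct:.1f}% fewer img/s")
```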
Hm, very interesting. Thanks for doing this work. Would you mind attaching GPU traces for, say, the first setup, both with and without FlexAttention?
You can gather traces using https://github.com/pytorch-labs/segment-anything-fast/tree/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/experiments#kernel-traces. Just make sure that `path` ends in `.json.gz`.
Using the GPU traces, it is also possible to annotate the changed section (using https://pytorch.org/docs/main/generated/torch.autograd.profiler.record_function.html#record-function and https://pytorch.org/docs/main/generated/torch.cuda.synchronize.html#torch-cuda-synchronize ) and compare only the runtime of its GPU kernels. This way we can double-check that the slowdown is due precisely to this change.
I'd create two versions of these traces, one with annotation and sync and one without, so four traces in total: a) baseline without annotation, b) baseline with annotation, c) changed without annotation, d) changed with annotation.
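A minimal sketch of how the annotated and un-annotated traces could be collected with `torch.profiler`, assuming the changed attention call can be wrapped in a context manager. `annotate`, `attention_step` and the range name `sdpa_under_test` are hypothetical, and plain SDPA stands in for whichever attention path (Triton kernel or FlexAttention) is being measured.

```python
import contextlib

import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile, record_function


@contextlib.contextmanager
def annotate(name: str, enabled: bool = True):
    """Label a code section in the trace; CUDA is synchronized on both sides so the
    kernels launched inside the section finish before the labeled range closes."""
    if not enabled:
        yield
        return
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    with record_function(name):
        yield
        if torch.cuda.is_available():
            torch.cuda.synchronize()


def attention_step(q, k, v, annotated: bool):
    # Stand-in for the changed code path (Triton kernel or FlexAttention call).
    with annotate("sdpa_under_test", enabled=annotated):
        return F.scaled_dot_product_attention(q, k, v)


device = "cuda" if torch.cuda.is_available() else "cpu"
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)
q, k, v = (torch.randn(1, 16, 256, 64, device=device) for _ in range(3))

for annotated in (False, True):
    with profile(activities=activities) as prof:
        attention_step(q, k, v, annotated)
    # export_chrome_trace gzips the output when the path ends in .gz.
    prof.export_chrome_trace(f"trace_annotated_{annotated}.json.gz")
```

Running the same script once for the baseline and once for the changed code gives the four traces. The extra synchronization in the annotated runs removes CPU/GPU overlap, so end-to-end timings are best read from the un-annotated traces.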
https://github.com/pytorch-labs/segment-anything-fast/ uses custom Triton code to implement a variant of SDPA that supports the kind of additive attention required by the image_encoder.
In a nutshell, this custom Triton kernel computes scaled dot-product attention with an additive bias term added to the attention scores before the softmax.
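For reference, an unfused version of such an additive-bias attention can be written with plain SDPA and a float `attn_mask`. This is a minimal sketch assuming SAM's decomposed relative-position bias, where a per-query row term `rel_h` and column term `rel_w` are added to the attention logits; the function name and exact tensor layouts are illustrative assumptions, not the code from segment-anything-fast.

```python
import torch
import torch.nn.functional as F


def attention_rel_h_rel_w_reference(q, k, v, rel_h, rel_w):
    """Unfused attention with a decomposed relative-position bias (illustrative).

    Assumed shapes:
      q:     (B, H, q_h * q_w, D)
      k, v:  (B, H, k_h * k_w, D)
      rel_h: (B, H, q_h * q_w, k_h, 1)   row term for each key row
      rel_w: (B, H, q_h * q_w, 1, k_w)   column term for each key column
    """
    B, H, Lq, _ = q.shape
    k_h, k_w = rel_h.size(3), rel_w.size(4)
    # Broadcasting gives one bias per (query, key row, key column); flattening the
    # last two dims matches a row-major flattening of the key grid.
    attn_bias = (rel_h + rel_w).reshape(B, H, Lq, k_h * k_w)
    # A float attn_mask is added to the scaled scores q @ k^T / sqrt(D) before softmax.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)


B, H, q_h, q_w, D = 1, 2, 4, 4, 16
q, k, v = (torch.randn(B, H, q_h * q_w, D) for _ in range(3))
rel_h = torch.randn(B, H, q_h * q_w, q_h, 1)
rel_w = torch.randn(B, H, q_h * q_w, 1, q_w)
out = attention_rel_h_rel_w_reference(q, k, v, rel_h, rel_w)
print(out.shape)
```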
With the release of FlexAttention in PyTorch 2.5 (code examples), it should now be possible to express this without custom Triton code.
Not only should FlexAttention support fused implementations for more input shapes, it is also likely to produce better-optimized code with better-tuned hyperparameters. This kind of fused attention gave an end-to-end improvement of about 1.15x on top of a baseline using fused SDPA and torch.compile (with CUDA graphs).
The task:

- Copy the relevant files from segment-anything-fast into torchao's model folder and follow the README to rerun if needed.
- Write a FlexAttention version of `flash_4` and measure the difference in performance. If it helps, we can immediately land it in torchao; at a minimum, it could influence FlexAttention development.
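For the measurement itself, a small timing harness built on `torch.utils.benchmark.Timer` could compare two attention paths on identical inputs. This is a sketch: `median_ms`, `compare` and `manual_attention` are hypothetical helpers, and the CPU usage below only exercises the harness with a hand-written softmax attention as the candidate. A real comparison would pass the `flash_4` path and a compiled FlexAttention function on an sm_80+ GPU.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark


def median_ms(fn, *args, min_run_time: float = 0.5) -> float:
    """Median runtime of fn(*args) in milliseconds.

    Timer synchronizes CUDA when needed, so asynchronous GPU kernels are
    included in the measurement rather than just their launch cost.
    """
    timer = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    return timer.blocked_autorange(min_run_time=min_run_time).median * 1e3


def compare(baseline_fn, candidate_fn, *args) -> float:
    """Runtime ratio candidate / baseline; a value above 1.0 means the candidate is slower."""
    # Call each once up front so one-time costs (compilation, autotuning) are not timed.
    baseline_fn(*args)
    candidate_fn(*args)
    return median_ms(candidate_fn, *args) / median_ms(baseline_fn, *args)


def manual_attention(q, k, v):
    # Unfused stand-in candidate: softmax(q @ k^T / sqrt(D)) @ v.
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ v


q, k, v = (torch.randn(1, 4, 256, 64) for _ in range(3))
ratio = compare(F.scaled_dot_product_attention, manual_attention, q, k, v)
print(f"manual / SDPA runtime ratio: {ratio:.2f}")
```

Timing only the attention call in isolation complements the end-to-end img/s numbers from `benchmark.sh`, since it separates the kernel's own cost from the rest of the model.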