Open cpuhrsch opened 3 months ago
@cpuhrsch
I would like to give this a shot. Could you help me clarify something?
Is the goal to make a fork of segment-anything-fast that uses Flex Attention, and test that in ao? The alternative would be to manually copy over all the files from segment-anything-fast to ao/torchao/_models/sam/, but that seems overkill since the only change is in the SDPA call.
What I could do is make a fork of segment-anything-fast that uses Flex Attention and use that as an alternative pip install to pip3 install git+https://github.com/pytorch-labs/segment-anything-fast.git when benchmarking SAM.
Let me know if this makes any sense, or if you meant something else.
@tobiasvanderwerff - Yes, we could also get started with an experimental PR against https://github.com/pytorch-labs/segment-anything-fast . Eventually it could be convenient to be able to vendor the changes in SAM-fast and make them more easily accessible via torchao packaging and distribution. What do you think about this?
@cpuhrsch that sounds like a plan. Let me try to get started on this in the next few days.
I already tried to run the SAM benchmark today to get started but realized that my current GPU (NVIDIA T4) does not support Flash Attention (since it requires compute capability >=sm_80, e.g. an A100). However, I intend to get access to a cloud A100 GPU instance in the next few days.
If getting access to a better GPU doesn't work out, I don't think I'll be able to work on this, and I'll let you know in that case.
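A quick way to check whether a given GPU clears the sm_80 bar mentioned above is sketched here. The helper names `supports_flash_attention` and `current_device_capability` are made up for illustration; only `torch.cuda.is_available` and `torch.cuda.get_device_capability` are real PyTorch calls. The threshold is taken from the comment above, not from an authoritative support matrix.

```python
from typing import Optional, Tuple


def supports_flash_attention(capability: Tuple[int, int]) -> bool:
    """Return True if (major, minor) compute capability is at least sm_80.

    The (8, 0) threshold is the one quoted in this thread.
    """
    return tuple(capability) >= (8, 0)


def current_device_capability() -> Optional[Tuple[int, int]]:
    """Return the capability of the current CUDA device, or None without a GPU."""
    import torch  # imported lazily so the pure check above works without torch

    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


# A T4 reports (7, 5) and an A100 reports (8, 0).
print(supports_flash_attention((7, 5)))
print(supports_flash_attention((8, 0)))
```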
@cpuhrsch as discussed, I've created a fork of the segment-anything-fast repo that uses Flex Attention instead of the custom Triton kernel. I've also added a test to check for correctness. You can see the changes here.
I'm posting benchmark results from ao/torchao/_models/sam/benchmark.sh below. First results are not terribly encouraging: the Flex Attention implementation leads to a ~25% reduction in img/s. I might do some more digging to see why this is happening. If you have any suggestions, I'd love to hear them.
As a side note, Flex Attention only accepts embedding sizes that are powers of two, so I had to add padding to make it work. It's possible that the padding leads to the negative effect in performance, although the Triton kernel seems to do the same thing.
Torch version: 2.6.0.dev20240918
GPU: A100 80GB
Baseline results (using Triton kernel):

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_h | 32 | 15172 | 18 | 22.533401716616083 | 44.37856354651513 | 0.5812715827356921 | max-autotune | torch.bfloat16 | None | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 15154 | 18 | 25.16516896830006 | 39.73746416166231 | 0.5818834536577897 | max-autotune | torch.bfloat16 | int8_dynamic_quant | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 15632 | 19 | 24.824717871078573 | 40.282431614863405 | 0.5675837487618974 | max-autotune | torch.bfloat16 | sparse_mlp_only | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 13429 | 16 | 24.589577947798148 | 40.66763578142439 | 0.5306639662569573 | max-autotune | torch.bfloat16 | sparse | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 14869 | 18 | 26.597207143088742 | 37.597932543073384 | 0.5669944616184625 | max-autotune | torch.bfloat16 | int8_dynamic_quant_sparse | False | True | True | 32 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 17068 | 21 | 23.96093702681232 | 41.73459489004953 | 0.5485481164943489 | max-autotune | torch.float16 | int4_weight_only_sparse | False | True | True | 32 | 154 | 4928 | None | None |
Flex Attention results (I omitted the last two rows because running the benchmark was taking a long time):

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_h | 32 | 19531 | 24 | 16.35339887491553 | 61.14936764209301 | 0.5812806843206303 | max-autotune | torch.bfloat16 | None | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 19512 | 24 | 17.72072649749095 | 56.43109497466644 | 0.5815980109018701 | max-autotune | torch.bfloat16 | int8_dynamic_quant | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 20960 | 25 | 16.6174344353318 | 60.177761127422386 | 0.5672995875671748 | max-autotune | torch.bfloat16 | sparse_mlp_only | False | True | True | 24 | 154 | 4928 | None | None |
| cuda | vit_h | 32 | 18997 | 23 | 14.915692058093141 | 67.04348655799767 | 0.5306602491658978 | max-autotune | torch.bfloat16 | sparse | False | True | True | 24 | 154 | 4928 | None | None |
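For reference, the per-configuration slowdown can be computed directly from the img_s(avg) columns of the two vit_h tables. The short script below only restates those numbers; the dictionary keys are the values of the compress column. Note that the two runs also differ in num_workers (32 vs 24), so the comparison is not perfectly controlled.

```python
# img_s(avg) values copied from the two vit_h tables, keyed by `compress`.
baseline = {
    "None": 22.533401716616083,
    "int8_dynamic_quant": 25.16516896830006,
    "sparse_mlp_only": 24.824717871078573,
    "sparse": 24.589577947798148,
}
flex = {
    "None": 16.35339887491553,
    "int8_dynamic_quant": 17.72072649749095,
    "sparse_mlp_only": 16.6174344353318,
    "sparse": 14.915692058093141,
}

# Ratio of Flex Attention throughput to baseline throughput per configuration.
ratios = {name: flex[name] / baseline[name] for name in flex}

for name, ratio in ratios.items():
    reduction = (1.0 - ratio) * 100.0
    print(f"{name}: {ratio:.2f}x of baseline img/s ({reduction:.0f}% reduction)")
```

By this measure the reduction ranges from roughly 27% (no compression) to roughly 39% (sparse).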
Hm, very interesting. Thanks for doing this work. Do you mind attaching GPU traces for say the first setup both with and without flexattention?
You can gather traces using https://github.com/pytorch-labs/segment-anything-fast/tree/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/experiments#kernel-traces . Just ensure that path ends in .json.gz.
Using the GPU traces it is also possible to annotate (using https://pytorch.org/docs/main/generated/torch.autograd.profiler.record_function.html#record-function and https://pytorch.org/docs/main/generated/torch.cuda.synchronize.html#torch-cuda-synchronize ) the section that was changed and look at the GPU kernel difference in runtime only. This way we can double check the slowdown is precisely due to this change.
I'd create two versions of these traces, one with annotation and sync and one without. That means 4 traces in total:
a) Baseline without annotate
b) Baseline with annotate
c) Changed without annotate
d) Changed with annotate
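A minimal sketch of what "annotate and sync" could look like is given below. The helpers `run_annotated` and `trace` and the label `changed_section` are hypothetical names, and this is not how the SAM benchmark script gathers its traces; it only illustrates combining `record_function` with `torch.cuda.synchronize` so that the annotated range brackets the GPU work of the changed section.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def run_annotated(fn, *args, label="changed_section"):
    """Call fn(*args) inside a named profiler range, syncing before and after."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain earlier kernels so they are not counted
    with record_function(label):
        out = fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for fn's kernels before closing the range
    return out


def trace(fn, *args, path="trace.json.gz", annotate=True):
    """Profile one call of fn and write a Chrome trace to `path`."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        if annotate:
            run_annotated(fn, *args)
        else:
            fn(*args)
    prof.export_chrome_trace(path)
    return prof
```

Calling `trace(...)` once with `annotate=True` and once with `annotate=False`, for both the baseline and the changed model, would give the four traces listed above.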
Tracing results indicate that in the Flex Attention version, a lot of time is spent on a padding kernel (triton_tem_fused_constant_pad_nd_38, indicated by blue arrows in the screenshot below):
The trace shows that the Flex Attention impl. spends 2 seconds in the image encoder, whereas the baseline spends only 1.35 seconds. So it definitely looks like quite a slowdown in the part of the code where SDPA is used.
Padding does not seem to take nearly as much time in the baseline (in the trace, the largest purple blocks under the image encoder block are calls to _fwd_kernel_aligned, the top level attention function):
So it seems that the padding is a large source of the slowdown. As I mentioned earlier, the Triton kernel does the same padding, but they somehow have made it more efficient. At the top of the function, it says:
"""
Writing this as a composite allows torch.compile to fuse
needed padding into previous operations and memory
allocations.
"""
So it seems like they somehow manage to make the padding more efficient by fusing it into earlier operations. I'm currently trying to figure out if this can also be done for the Flex Attention kernel, but it's not obvious to me how.
(NB: I also tried running the tracing with the annotations, as you suggested @cpuhrsch, but this did not seem to show up in the trace output - perhaps because of torch.compile?)
@cpuhrsch - Hm, the way you're using FlexAttention it should also be a composite (as in flex_attention_fwd is a composite just like _attention_rel_h_rel_w, because it's composed of multiple functions as opposed to just a single kernel).
Since this is needed specifically for vit_h, does it mean for vit_b the gap narrows or even with FlexAttention it's faster?
Also cc @Chillee and @drisspg
@cpuhrsch vit_b results show a similar gap between the baseline and Flex Attention. So even without padding, there is still a large diff in runtime!
Baseline:

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_b | 32 | 6631 | 8 | 87.1522144531224 | 11.47417775067416 | 0.5358536312719586 | max-autotune | torch.bfloat16 | None | False | True | True | 24 | 154 | 4928 | None | None |
Flex Attention:

| device | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | use_compile_decoder | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cuda | vit_b | 32 | 6969 | 8 | 38.57345144045247 | 25.92456631846242 | 0.536104508681229 | max-autotune | torch.bfloat16 | None | False | True | True | 24 | 154 | 4928 | None | None |
I looked at the profile traces but it is difficult to extract any useful information. Most of the kernels in the Flex Attention version have nondescript names like triton_tem_fused_13 or triton_tem_fused_31, so it's hard to know what exactly the GPU is spending its time on.
I may have found a clue as to where the performance bottleneck lies. Replacing this line in the score_mod function:
attn_bias = self.rel_h[batch, head, q_idx, h_idx] + self.rel_w[batch, head, q_idx, w_idx]
with this:
attn_bias = h_idx + w_idx
leads to a massive speedup (38 img/s -> 97 img/s). So it seems that the indexing into rel_h and rel_w is slowing things down a lot.
Unfortunately, using rel_h and rel_w in a different way (like passing them to the function without setting them as class attributes) leads to Torch Inductor errors when torch compiling. I've reached a point where I'm really not sure how to deal with this, so I've opened an issue in the Flex Attention repo that reproduces the issue. Hopefully, the Flex Attention authors can provide some more clarity.
Great, thank you for the investigation @tobiasvanderwerff !
@tobiasvanderwerff - For what it's worth, indexing into the rel_h and rel_w Tensors efficiently is a key reason why flash_4 can provide a speedup over SDPA to begin with. It's not a better implementation of SDPA, it just avoids the materialization of (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)).
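To illustrate the difference between indexing and materializing, the sketch below writes the bias as a `score_mod`-style function and checks it against the materialized tensor. Several things here are assumptions: `rel_h` and `rel_w` are taken to be 4-D tensors of shape (B, H, Q, k_h) and (B, H, Q, k_w), and the key index is assumed to decompose as `h_idx = kv_idx // k_w`, `w_idx = kv_idx % k_w`. The helper names `make_rel_pos_score_mod` and `materialized_bias` are made up. The check runs in plain PyTorch and does not call FlexAttention itself.

```python
import torch


def make_rel_pos_score_mod(rel_h: torch.Tensor, rel_w: torch.Tensor):
    """Build a score_mod that adds the decomposed relative-position bias."""
    k_w = rel_w.shape[-1]

    def score_mod(score, b, h, q_idx, kv_idx):
        # Assumed row-major layout of key positions on a k_h x k_w grid.
        h_idx = kv_idx // k_w
        w_idx = kv_idx % k_w
        return score + rel_h[b, h, q_idx, h_idx] + rel_w[b, h, q_idx, w_idx]

    return score_mod


def materialized_bias(rel_h: torch.Tensor, rel_w: torch.Tensor) -> torch.Tensor:
    """The full (B, H, Q, k_h * k_w) bias that the fused kernel avoids building."""
    B, H, Q, k_h = rel_h.shape
    k_w = rel_w.shape[-1]
    return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape(B, H, Q, k_h * k_w)


torch.manual_seed(0)
B, H, k_h, k_w = 2, 3, 4, 4
Q = k_h * k_w
rel_h = torch.randn(B, H, Q, k_h)
rel_w = torch.randn(B, H, Q, k_w)

# Evaluate score_mod on every (b, h, q, kv) combination via broadcast indices.
b = torch.arange(B).view(B, 1, 1, 1)
h = torch.arange(H).view(1, H, 1, 1)
q = torch.arange(Q).view(1, 1, Q, 1)
kv = torch.arange(k_h * k_w).view(1, 1, 1, -1)
score_mod = make_rel_pos_score_mod(rel_h, rel_w)
from_score_mod = score_mod(torch.zeros(B, H, Q, k_h * k_w), b, h, q, kv)

print(torch.allclose(from_score_mod, materialized_bias(rel_h, rel_w)))
```

A function of this shape is what would be passed as `score_mod` to `torch.nn.attention.flex_attention.flex_attention(q, k, v, score_mod=score_mod)`; the small tensors of size k_h and k_w are read per score instead of the full Q x (k_h * k_w) bias.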
@cpuhrsch an update:
I've tried the fix pushed by @Chillee, but unfortunately I still get an error (see output below). It looks like the minified code sample I referred to in the issue does not quite transfer to the more complicated setup of the SAM-fast model. I'm not really sure how to resolve this right now, and unfortunately it is not very feasible for me to keep using an A100 for testing due to expenses (sorry). So the best strategy may be to put this on hold right now and perhaps wait until FlexAttention manages this issue at some point.
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'get_stride'
target: flex_attention
args[0]: TensorBox(
View(
View(
SliceView(
View(
StorageBox(
ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.bfloat16, size=[3, 800, 12, 196, 64], stride=[120422400, 150528, 12544, 64, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
def inner_fn(index):
i0, i1, i2, i3, i4 = index
tmp0 = ops.load(buf6, i4 + 64 * i2 + 768 * i0 + 2304 * ModularIndexing(i3, 1, 14) + 32256 * ModularIndexing(i3, 14, 14) + 451584 * i1)
tmp1 = ops.load(arg7_1, i4 + 64 * i2 + 768 * i0)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[3, 800, 12, 196, 64],
origin_node=clone_2,
origins=OrderedSet([clone_2])
))
),
size=[3, 9600, 196, 64],
reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 12, 800), ModularIndexing(i1, 1, 12), i2, i3],
origins=OrderedSet([view_5, clone_2])
),
size=[1, 9600, 196, 64],
reindex=lambda i0, i1, i2, i3: [i0, i1, i2, i3],
origins=OrderedSet([unbind])
),
size=[9600, 196, 64],
reindex=lambda i0, i1, i2: [0, i0, i1, i2],
origins=OrderedSet([unbind])
),
size=[800, 12, 196, 64],
reindex=lambda i0, i1, i2, i3: [12*i0 + i1, i2, i3],
origins=OrderedSet([view_17])
)
)
args[1]: TensorBox(
View(
View(
SliceView(
View(
StorageBox(
ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.bfloat16, size=[3, 800, 12, 196, 64], stride=[120422400, 150528, 12544, 64, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
def inner_fn(index):
i0, i1, i2, i3, i4 = index
tmp0 = ops.load(buf6, i4 + 64 * i2 + 768 * i0 + 2304 * ModularIndexing(i3, 1, 14) + 32256 * ModularIndexing(i3, 14, 14) + 451584 * i1)
tmp1 = ops.load(arg7_1, i4 + 64 * i2 + 768 * i0)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[3, 800, 12, 196, 64],
origin_node=clone_2,
origins=OrderedSet([clone_2])
))
),
size=[3, 9600, 196, 64],
reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 12, 800), ModularIndexing(i1, 1, 12), i2, i3],
origins=OrderedSet([view_5, clone_2])
),
size=[1, 9600, 196, 64],
reindex=lambda i0, i1, i2, i3: [i0 + 1, i1, i2, i3],
origins=OrderedSet([unbind])
),
size=[9600, 196, 64],
reindex=lambda i0, i1, i2: [0, i0, i1, i2],
origins=OrderedSet([unbind])
),
size=[800, 12, 196, 64],
reindex=lambda i0, i1, i2, i3: [12*i0 + i1, i2, i3],
origins=OrderedSet([view_18])
)
)
args[2]: TensorBox(
View(
View(
SliceView(
View(
StorageBox(
ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.bfloat16, size=[3, 800, 12, 196, 64], stride=[120422400, 150528, 12544, 64, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
def inner_fn(index):
i0, i1, i2, i3, i4 = index
tmp0 = ops.load(buf6, i4 + 64 * i2 + 768 * i0 + 2304 * ModularIndexing(i3, 1, 14) + 32256 * ModularIndexing(i3, 14, 14) + 451584 * i1)
tmp1 = ops.load(arg7_1, i4 + 64 * i2 + 768 * i0)
tmp2 = tmp0 + tmp1
return tmp2
,
ranges=[3, 800, 12, 196, 64],
origin_node=clone_2,
origins=OrderedSet([clone_2])
))
),
size=[3, 9600, 196, 64],
reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 12, 800), ModularIndexing(i1, 1, 12), i2, i3],
origins=OrderedSet([view_5, clone_2])
),
size=[1, 9600, 196, 64],
reindex=lambda i0, i1, i2, i3: [i0 + 2, i1, i2, i3],
origins=OrderedSet([unbind])
),
size=[9600, 196, 64],
reindex=lambda i0, i1, i2: [0, i0, i1, i2],
origins=OrderedSet([unbind])
),
size=[800, 12, 196, 64],
reindex=lambda i0, i1, i2, i3: [12*i0 + i1, i2, i3],
origins=OrderedSet([view_19])
)
)
args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
args[4]: (TensorBox(StorageBox(
ComputedBuffer(name='buf15', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _ = index
tmp0 = ops.constant(1, torch.int32)
return tmp0
,
ranges=[1, 1, 1],
origin_node=full,
origins=OrderedSet([full])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf16', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _, _ = index
tmp0 = ops.constant(0, torch.int32)
return tmp0
,
ranges=[1, 1, 1, 1],
origin_node=full_default,
origins=OrderedSet([full_default])
))
)), None, None, TensorBox(StorageBox(
ComputedBuffer(name='buf17', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _ = index
tmp0 = ops.load(buf7, 0)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1],
origin_node=convert_element_type_11,
origins=OrderedSet([convert_element_type_11, sum_1])
))
)), TensorBox(StorageBox(
ComputedBuffer(name='buf18', layout=FlexibleLayout('cuda', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
'cuda',
torch.int32,
def inner_fn(index):
_, _, _, _ = index
tmp0 = ops.index_expr(0, dtype=torch.int16)
tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)
tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
return tmp2
,
ranges=[1, 1, 1, 1],
origin_node=convert_element_type_12,
origins=OrderedSet([convert_element_type_12, sort])
))
)), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
args[5]: 0.125
args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': False}
args[7]: (TensorBox(
View(
View(
View(
StorageBox(
Pointwise(
'cuda',
torch.bfloat16,
def inner_fn(index):
i0, i1, i2, i3, _ = index
tmp0 = ops.load(buf11, i3 + 16 * i2 + 224 * i0 + 2150400 * i1)
return tmp0
,
ranges=[9600, 14, 14, 14, 1],
origin_node=clone_4,
origins=OrderedSet([clone_4])
)
),
size=[9600, 196, 14, 1],
reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 14, 14), ModularIndexing(i1, 1, 14), i2, 0],
origins=OrderedSet([view_15, clone_4])
),
size=[800, 12, 196, 14, 1],
reindex=lambda i0, i1, i2, i3, i4: [12*i0 + i1, i2, i3, 0],
origins=OrderedSet([view_20])
),
size=[800, 12, 196, 14],
reindex=lambda i0, i1, i2, i3: [i0, i1, i2, i3, 0],
origins=OrderedSet([squeeze])
)
), TensorBox(
View(
View(
View(
StorageBox(
Pointwise(
'cuda',
torch.bfloat16,
def inner_fn(index):
i0, i1, i2, _, i4 = index
tmp0 = ops.load(buf14, i4 + 16 * i1 + 224 * i0 + 2150400 * i2)
return tmp0
,
ranges=[9600, 14, 14, 1, 14],
origin_node=clone_5,
origins=OrderedSet([clone_5])
)
),
size=[9600, 196, 1, 14],
reindex=lambda i0, i1, i2, i3: [i0, ModularIndexing(i1, 14, 14), ModularIndexing(i1, 1, 14), 0, i3],
origins=OrderedSet([clone_5, view_16])
),
size=[800, 12, 196, 1, 14],
reindex=lambda i0, i1, i2, i3, i4: [12*i0 + i1, i2, 0, i4],
origins=OrderedSet([view_21])
),
size=[800, 12, 196, 14],
reindex=lambda i0, i1, i2, i3: [i0, i1, i2, 0, i3],
origins=OrderedSet([squeeze_1])
)
))
args[8]: ()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
@tobiasvanderwerff - Thank you for testing this. I'll update https://github.com/pytorch-labs/attention-gym/issues/45 as well. At least with the most recent fix we're one step closer.
https://github.com/pytorch-labs/segment-anything-fast/ uses custom Triton code to implement a variant of SDPA that supports the kind of additive attention required by the image_encoder.
In a nutshell the code it implements using this custom Triton kernel is
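A minimal sketch of that computation follows, reconstructed from the expression quoted in the comments of this thread rather than copied from the repository. It assumes `rel_h` has shape (B, H, Q, k_h, 1) and `rel_w` has shape (B, H, Q, 1, k_w); the function name `attention_with_rel_pos` is hypothetical.

```python
import torch
import torch.nn.functional as F


def attention_with_rel_pos(q, k, v, rel_h, rel_w):
    """SDPA with an additive bias built from decomposed relative positions."""
    B, H = q.shape[0], q.shape[1]
    # Broadcasting (.., k_h, 1) + (.., 1, k_w) materializes the full bias.
    attn_bias = (rel_h + rel_w).view(
        B, H, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
    )
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)


torch.manual_seed(0)
B, H, k_h, k_w, D = 2, 3, 4, 4, 8
Q = k_h * k_w
q, k, v = (torch.randn(B, H, Q, D) for _ in range(3))
rel_h = torch.randn(B, H, Q, k_h, 1)
rel_w = torch.randn(B, H, Q, 1, k_w)

out = attention_with_rel_pos(q, k, v, rel_h, rel_w)
print(out.shape)
```

The custom Triton kernel computes the same result without ever building `attn_bias` in memory.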
With the release of FlexAttention in PyTorch 2.5 (code examples), it should now be possible to express this without the need for custom Triton code.
Not only will FlexAttention be able to support fused implementations for more input shapes, it is also likely to produce more optimal code with better hyperparameters. This kind of fused attention caused an end-to-end improvement of about 1.15x on top of a baseline that already used fused SDPA and torch.compile (with CUDA graphs).
The task:
Copy over the relevant files from segment-anything-fast into torchao's model folder and follow the readme to rerun if needed.
Write a FlexAttention version of flash_4 and measure difference in performance. If it helps, we can immediately land it in torchao, but at a minimum it could influence FlexAttention development.