Open jjsjann123 opened 7 months ago
A similar problem showed up with layernorm + amax. The segmenter is breaking the amax into a separate kernel, where it should instead break the amax into two reduction kernels, so that the first half can be fused into the layernorm, saving bandwidth on both kernels (see the sketch after the snippet below).
out1 = (x + bias) + residual
out1 = out1.view(-1, w_shape[-1])
# LayerNorm
w = ln_weight if not zero_centered_gamma else 1 + ln_weight
out2 = torch.nn.functional.layer_norm(out1, w_shape, w, ln_bias, eps)
# Obtain FP8 amax
amax = torch.amax(torch.abs(out2))
amax_tensor.fill_(amax)
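For reference, a minimal PyTorch sketch of the decomposition being suggested (shapes are made up for illustration; this is plain eager PyTorch to show the numerics, not the nvFuser scheduling itself): the full amax is split into a per-row amax that reduces over the same dimension layer_norm normalizes, plus a small follow-up amax over the per-row partials.

```python
import torch

# Made-up shapes, purely for illustration
x = torch.randn(512, 1024)
w_shape = (1024,)
ln_weight = torch.randn(w_shape)
ln_bias = torch.randn(w_shape)
eps = 1e-5

out2 = torch.nn.functional.layer_norm(x, w_shape, ln_weight, ln_bias, eps)

# What the snippet above computes: one full reduction over all dims,
# which the segmenter currently splits off into its own kernel
amax_full = torch.amax(torch.abs(out2))

# Decomposed version: stage 1 reduces over the same dim layer_norm normalizes
# (so it could be fused with the layer_norm kernel); stage 2 is a tiny
# reduction over the per-row partials
amax_rows = torch.amax(torch.abs(out2), dim=-1)  # shape [512]
amax_two_stage = torch.amax(amax_rows)           # scalar

torch.testing.assert_close(amax_full, amax_two_stage)
```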
Had a quick chat with @naoyam. The takeaway: as long as the reduction tensor matches the reference tensor of the normalization, it should be safe to fuse. That does not apply to the case above, but we could still break the amax into two amax reductions and match the first one with the layernorm.
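To restate that condition in code, a hypothetical, simplified check (this is not nvFuser's segmenter logic, just the idea spelled out): a reduction is a fusion candidate when it reduces over exactly the dims the normalization reduces over.

```python
# Hypothetical helper, not an nvFuser API: restates the condition above.
def reduction_matches_norm(reduction_dims, norm_dims):
    # Safe-to-fuse candidate when the reduction covers exactly the dims
    # the normalization reduces over to build its reference tensor.
    return set(reduction_dims) == set(norm_dims)

norm_dims = [1]                            # layer_norm over the last dim of a 2D input
reduction_matches_norm([1], norm_dims)     # True:  per-row amax -> fusible
reduction_matches_norm([0, 1], norm_dims)  # False: full amax    -> needs to be split first
```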
Need to think more about how this could be automatically figured out... I'm wondering if it makes sense to always decompose a full reduction and try our luck with it. :confused:
I'm putting it here for backlog.
Looking at the example below, the two models produce the same result. By pulling the reduction op ahead of the broadcast+mul, codegen seems to get much better performance on the toy example.
Though I'm currently hitting bug #1659 when trying to scale it up. I'll update this issue once I patch the transpose scheduler.
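The toy example itself isn't reproduced here, but as a guess at the kind of rewrite "pulling the reduction op ahead of the broadcast+mul" refers to (my own illustration, shapes assumed): when the multiplied operand is broadcast along the reduced dimension, the reduction distributes over the multiply, so both orderings produce the same result.

```python
import torch

a = torch.randn(128, 1024)
b = torch.randn(128, 1)   # broadcast along the dim being reduced

# Model 1: broadcast + mul first, then the reduction
ref = torch.amax(torch.abs(a * b), dim=1)

# Model 2: reduction pulled ahead of the broadcast + mul
out = torch.abs(b.squeeze(1)) * torch.amax(torch.abs(a), dim=1)

torch.testing.assert_close(ref, out)
```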