Open RedRAINXXXX opened 1 month ago
Which PyTorch version are you using? I'm wondering if this is related to #1217; could you try that fix?
My torch version is 2.4.1+cu124.
I've tried that (setting NVTE_TORCH_COMPILE to 1), but it didn't work. I also tried removing the @jit_fuser decorator, but that didn't help either. :(
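In case it matters for reproducing, this is roughly how I toggled the flag (a minimal sketch; I'm assuming the variable has to be set before transformer_engine is imported, since the jit_fuser backend seems to be chosen at import time, and the value expected by the #1217 workaround may differ):

```python
import os

# Assumption: NVTE_TORCH_COMPILE is read when transformer_engine is first imported,
# so it must be set before the import below. "1" is the value I tried.
os.environ["NVTE_TORCH_COMPILE"] = "1"

import torch
import transformer_engine.pytorch as te  # flag must already be set at this point
```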
Both softmax_lse and softmax_lse_per_step have size (36, 13056), i.e. they are only 2-D, so the call to movedim(2, seq_dim) fails because dimension 2 does not exist.
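A minimal sketch of the failure using the shapes I see at the call site (the tensor values are just random stand-ins):

```python
import torch

# Shapes observed at the failing call site; contents are random stand-ins.
softmax_lse = torch.randn(36, 13056)            # only dims 0 and 1 exist
softmax_lse_per_step = torch.randn(36, 13056)
seq_dim = 1                                     # matches the (..., 2, 1) args in the traceback

corrected = torch.exp(softmax_lse_per_step - softmax_lse)
corrected = corrected.movedim(2, seq_dim)       # raises: dimension 2 is out of range for a 2-D tensor
```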
Versions:
flash-attn 2.6.3
transformer-engine 1.11.0+c27ee60
flashattn-hopper 3.0.0b1
Bug report

The bug occurs in this function:
```python
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism"""
    softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
    softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
    out_corrected = out_per_step * softmax_lse_corrected_exp
    out.add_(out_corrected)
```
The final line of the traceback is:

```
torch._dynamo.exc.TorchRuntimeError: Failed running call_method movedim(*(FakeTensor(..., device='cuda:4', size=(36, 13056)), 2, 1), **{}):
```
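If I read the function correctly, it expects softmax_lse to be 3-D (something like [batch, heads, seq]) so that movedim(2, seq_dim) can align the sequence dimension with out. A small sketch of the correction with made-up shapes that do satisfy that assumption (b, h, sq, d are placeholder values, not the sizes from my run):

```python
import torch

# Placeholder shapes; the real sizes in my run differ.
b, h, sq, d = 2, 4, 8, 16
seq_dim = 1                                    # out laid out as [b, sq, h, d]

out = torch.zeros(b, sq, h, d)
out_per_step = torch.randn(b, sq, h, d)
softmax_lse = torch.randn(b, h, sq)            # 3-D, as the function appears to expect
softmax_lse_per_step = torch.randn(b, h, sq)

corr = torch.exp(softmax_lse_per_step - softmax_lse)  # [b, h, sq]
corr = corr.movedim(2, seq_dim).unsqueeze(-1)         # [b, sq, h, 1], broadcasts over the head dim
out.add_(out_per_step * corr)

# In my case softmax_lse is 2-D with size (36, 13056), so there is no dim 2 to move,
# which is exactly where the dynamo trace above fails.
```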