WorldEditors opened this issue 13 hours ago
Thank you for reporting this issue. Could you elaborate on the input shapes so that I can run some simulation experiments?
Some debugging shows that the following part of chunk.py (around line 1016) causes the problem:
```python
grid = (NV, NT * NC, B * HQ)
chunk_gsa_bwd_kernel_intra_KV[grid](
    v, g, o, A, do, dv, dg,
    v.stride(1), v.stride(2),
    T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG,
    OVERWRITE_DG=overwrite_dg,
    num_warps=num_warps,
    num_stages=num_stages
)  # after this call, dv and dg suddenly go to NaN
return dq, dk, dv, dg, dh0
```
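A minimal sketch of the kind of probe used to confirm this (the helper name, placement, and error message are illustrative, not the exact debugging code):

```python
import torch

def assert_finite(**tensors: torch.Tensor) -> None:
    """Raise if any tensor contains NaN/Inf values."""
    for name, t in tensors.items():
        bad = ~torch.isfinite(t)
        if bad.any():
            first = bad.nonzero(as_tuple=False)[0].tolist()
            raise RuntimeError(
                f"{name}: {int(bad.sum())} non-finite values, first at index {first}"
            )

# e.g. assert_finite(dv=dv, dg=dg) placed immediately after the
# chunk_gsa_bwd_kernel_intra_KV[grid](...) launch above
```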
Hyperparameters: hidden_size=512, num_slots=64, nheads=4. Everything else is left at the GatedSlotAttention defaults.
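For completeness, the configuration in question is roughly the following (a sketch: the keyword names are taken from this report, and I am assuming they map directly onto the GatedSlotAttention constructor):

```python
# Reported hyperparameters; anything not listed stays at the GatedSlotAttention default.
gsa_kwargs = dict(
    hidden_size=512,
    num_slots=64,
    num_heads=4,  # "nheads" above
)
# layer = GatedSlotAttention(**gsa_kwargs)  # assuming the constructor accepts these names
```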
@WorldEditors How about the sequence length?
I tried sequence lengths of 3K, 12K, and 24K; the problem is reproducible with all of them.
In GSA, I tried resuming training from a problematic checkpoint, and there are some interesting findings: training proceeds normally only if I reinitialize the f_proj parameters, while reinitializing the other parameters (q_proj, k_proj, v_proj, g_norm, o_proj) does not solve the problem. So I guess there is something wrong with the f_proj parameters? But then why does a similar problem appear in RWKV6? I have no idea.
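The reinitialization experiment looks roughly like this (a sketch: the checkpoint loading, module naming, and init scheme are placeholders rather than the exact code used):

```python
import torch
import torch.nn as nn

def resume_with_fresh_f_proj(model: nn.Module, ckpt_path: str) -> None:
    """Load the problematic checkpoint, then reinitialize only the f_proj layers."""
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    for name, module in model.named_modules():
        if name.endswith("f_proj") and isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    # Training resumed from this state runs without NaNs; doing the same for
    # q_proj / k_proj / v_proj / g_norm / o_proj instead does not help.
```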
Hi, thanks for reporting it! Do you have an input tensor and model weights with which we can reproduce it?
Describe the bug
Running training for GSA and RWKV occasionally results in NaN gradients: rare at the beginning, but increasingly frequent as training progresses. I checked the parameters and losses; all of them are reasonable and show no sign of explosion. The NaN appears suddenly. When switching the model back to a Transformer, this never happens.
Using with torch.autograd.detect_anomaly(), I get the following log:
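A minimal sketch of how the anomaly check is wired in (the training loop, model, and batch here are placeholders, not the actual training code):

```python
import torch

torch.autograd.set_detect_anomaly(True)  # report the backward op that produced NaN

def training_step(model, optimizer, batch):
    # Placeholder loop; the real loss computation differs.
    optimizer.zero_grad()
    loss = model(**batch).loss
    loss.backward()  # with anomaly detection on, this raises at the offending op
    # Extra guard: refuse to step if any gradient is non-finite.
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite gradient in {name}")
    optimizer.step()
```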
Steps to reproduce the bug
I cannot provide a code sample; the problem does not happen with a specific model at a specific step.
Expected behavior
N/A
Environment info
CUDA: 12.6, NVIDIA A800 80GB PCIe