sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License

Loss becomes NaN with the updated rwkv6 #19

Closed JL-er closed 5 months ago

JL-er commented 6 months ago

I'm currently using the version from a few days ago, and the loss is normal.

yzhangcs commented 6 months ago

Oh, looks like you may need to switch back to logsigmoid; -exp is not stable yet.
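For context, a minimal sketch of the two decay parameterizations being discussed; the tensor shapes and the name `w_raw` are illustrative, not the library's actual layout:

```python
import torch
import torch.nn.functional as F

# Illustrative raw outputs of the time-decay projection, shape (batch, seq_len, dim).
w_raw = torch.randn(2, 16, 64, requires_grad=True)

# Option 1: logsigmoid parameterization (the stable choice referred to above).
# The log decay lies in (-inf, 0) and its gradient, 1 - sigmoid(w_raw), stays in (0, 1).
log_w_stable = F.logsigmoid(w_raw)

# Option 2: -exp parameterization (RWKV6-style), described above as not yet stable.
# Its gradient w.r.t. w_raw is -exp(w_raw), which is unbounded and can overflow.
log_w_exp = -torch.exp(w_raw)
```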

JL-er commented 6 months ago

[screenshot] This works; the loss is very stable, with essentially no error.

JL-er commented 6 months ago

[screenshot] It should be an issue with this update.

yzhangcs commented 6 months ago

This update fixes potential NaNs during inference, so I don't think that's the issue. It's possibly because of the potentially infinite gradient of -exp; I will check it, thank you.
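To illustrate the failure mode mentioned here (a standalone demonstration, not the library's code): the gradient of -exp(w) is -exp(w) itself, so it overflows for large w, while logsigmoid's gradient stays bounded.

```python
import torch
import torch.nn.functional as F

# -exp parameterization: gradient w.r.t. w equals -exp(w) and is unbounded.
w = torch.tensor([100.0], requires_grad=True)
(-torch.exp(w)).sum().backward()
print(w.grad)  # tensor([-inf]): exp(100) overflows even in float32
               # (in fp16, overflow already occurs around w > ~11)

# logsigmoid parameterization: gradient is sigmoid(-w), always finite and in (0, 1).
w2 = torch.tensor([100.0], requires_grad=True)
F.logsigmoid(w2).sum().backward()
print(w2.grad)  # a tiny finite value (~exp(-100)), no overflow
```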

JL-er commented 6 months ago

RWKV-PEFT has added fla and it currently works. But as soon as I switch to the new fla, the loss becomes NaN. If there are future fla updates, please let me know and I can test them.

JL-er commented 6 months ago

[screenshot] I don't know why fla's rwkv6 isn't faster than the CUDA kernel; when I tested gla before, it was much faster.

yzhangcs commented 6 months ago

Have you compared the kernel speeds?
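A generic timing harness one could use for that comparison; it makes no assumptions about fla's API beyond having two callables to time on identical inputs:

```python
import torch

def bench(fn, *args, warmup=10, iters=100):
    """Time a GPU kernel with CUDA events and return milliseconds per call."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Usage (placeholder callables): run both kernels on the same tensors, e.g.
#   bench(triton_rwkv6_fn, q, k, v, w, u)  vs  bench(cuda_rwkv6_fn, q, k, v, w, u)
```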

JL-er commented 6 months ago

I'll find time to test it. By the way, there's another problem: when I do state tuning, swapping in the fla operator raises an error. [screenshot] It's probably because the state doesn't retain gradients, so how can I solve this?

yzhangcs commented 6 months ago

You can enable gradients for h0 manually.

yzhangcs commented 6 months ago

Would taking h0 as a learnable parameter be OK? Like h0 = nn.Parameter(torch.zeros(key_dim, head_dim)).
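In runnable form, that suggestion might look like the sketch below; the shapes (num_heads, key_dim, head_dim) and the broadcasting over the batch are my assumptions, not a documented fla interface:

```python
import torch
import torch.nn as nn

# Hypothetical per-head recurrent state of size (key_dim, head_dim).
num_heads, key_dim, head_dim = 4, 64, 64

# nn.Parameter wraps a tensor (not bare dimensions); requires_grad defaults to True,
# so gradients w.r.t. the initial state are tracked once the kernel propagates them.
h0 = nn.Parameter(torch.zeros(num_heads, key_dim, head_dim))

# Broadcast the learnable state over the batch before handing it to the kernel;
# the exact argument name (e.g. initial_state=) depends on the fla version in use.
batch_size = 8
init_state = h0.unsqueeze(0).expand(batch_size, -1, -1, -1)
```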

JL-er commented 6 months ago

[screenshots] It runs fine when I use the CUDA operator, but not with fla. Normally the gradients of the state computed inside the operator are saved automatically.

JL-er commented 6 months ago

One more thing: I've frozen all other weights here and only kept gradients for the state.
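For reference, the freezing setup described here could look like the following sketch; `TinyModel` and the attribute name `init_state` are placeholders for the real model:

```python
import torch
import torch.nn as nn

# Toy stand-in for the actual RWKV model; only the structure matters here.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)
        self.init_state = nn.Parameter(torch.zeros(4, 64, 64))

model = TinyModel()

# Freeze every weight, then re-enable gradients only for the learnable state.
for p in model.parameters():
    p.requires_grad_(False)
model.init_state.requires_grad_(True)

# Only the state is handed to the optimizer, matching the setup described above.
optimizer = torch.optim.Adam([model.init_state], lr=1e-3)
```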

yzhangcs commented 6 months ago

I see. Currently there is no access to the gradients of the states; we will add an option later.

JL-er commented 6 months ago

thank you

yzhangcs commented 5 months ago

@JL-er Hi, check it out https://github.com/sustcsonglin/flash-linear-attention/commit/1547448b998a163fdb33c49266da699db13f2dc8

Now we do not truncate the gradient of the h states for RWKV6, for ease of state tuning. Do contact us if you meet any bugs or numerical stability issues :-D
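A hedged sketch of how state tuning could look after that commit; the import path, the argument names (initial_state, output_final_state), and the (batch, heads, seq, dim) layout are assumptions on my part and should be checked against the installed fla version:

```python
import torch
import torch.nn.functional as F
# Assumed import path and signature; verify against the current fla code.
from fla.ops.rwkv6 import chunk_rwkv6

B, H, T, K, V = 2, 4, 128, 64, 64
device, dtype = 'cuda', torch.bfloat16

q = torch.randn(B, H, T, K, device=device, dtype=dtype)
k = torch.randn(B, H, T, K, device=device, dtype=dtype)
v = torch.randn(B, H, T, V, device=device, dtype=dtype)
w = F.logsigmoid(torch.randn(B, H, T, K, device=device)).to(dtype)  # log decay < 0
u = torch.randn(H, K, device=device, dtype=dtype)

# Learnable initial state; after the commit above its gradient is no longer truncated.
h0 = torch.zeros(B, H, K, V, device=device, dtype=dtype, requires_grad=True)

o, ht = chunk_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
o.sum().backward()
print(h0.grad is not None)  # expected: True, i.e. gradients now reach the state
```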

JL-er commented 5 months ago

Testing on rwkv-peft is perfect; clipping is no longer needed. However, infctx training at 6000 ctx len would previously NaN occasionally (I will retest). Thank you very much.