I think you should add a small epsilon to the denominator in the output of the `quadratic_linear_attn` function to prevent NaN values when training the Hedgehog MLP. If a row of `qk` sums to zero, the current normalization computes 0/0:
`qk / (qk.sum(dim=-1, keepdim=True) + epsilon)`
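A minimal sketch of what the guarded normalization might look like. This is not the repository's actual `quadratic_linear_attn`: the signature, the `(batch, heads, seq_len, feature_dim)` tensor layout, the `eps` default of `1e-6`, and the optional causal mask are all assumptions for illustration. The inputs are assumed to be the already feature-mapped queries and keys.

```python
import torch


def quadratic_linear_attn(q, k, eps=1e-6, causal=True):
    # Hypothetical signature: q, k are (batch, heads, seq_len, feature_dim)
    # tensors that have already passed through the (non-negative) feature map.
    qk = torch.einsum("bhmd,bhnd->bhmn", q, k)
    if causal:
        m, n = qk.shape[-2], qk.shape[-1]
        # Zero out scores where the key position is after the query position.
        future = torch.ones(m, n, device=qk.device).triu(1).bool()
        qk = qk.masked_fill(future, 0.0)
    # eps keeps the row normalizer strictly positive, so a row whose scores
    # are all zero yields zeros instead of 0 / 0 = NaN.
    return qk / (qk.sum(dim=-1, keepdim=True) + eps)


if __name__ == "__main__":
    q = torch.zeros(1, 1, 4, 8)  # degenerate queries: every score row is zero
    k = torch.rand(1, 1, 4, 8)
    unguarded = torch.einsum("bhmd,bhnd->bhmn", q, k)
    print(torch.isnan(unguarded / unguarded.sum(-1, keepdim=True)).any())  # NaNs without eps
    print(torch.isnan(quadratic_linear_attn(q, k)).any())  # no NaNs with eps
```

The epsilon only covers rows that sum to exactly (or nearly) zero with non-negative scores; if the feature map could produce negative values, the denominator could still approach zero, so that assumption about the feature map matters.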