YuchuanTian / DiJiang

[ICML'24 Oral] The official code of "DiJiang: Efficient Large Language Models through Compact Kernelization", a novel DCT-based linear attention mechanism.
https://arxiv.org/abs/2403.19928

Long inputs cause overflow / underflow #4

Open · yuji96 opened 2 months ago

yuji96 commented 2 months ago

Thank you for sharing the implementation of this interesting work!

When training DiJiang on long inputs (more than 5,000 tokens), the outputs became NaN. This was caused by overflow: D2 is defined as the $(-n)$-th power of values less than 1, so it grows without bound as the position $n$ increases. (Meanwhile D1 approaches zero, so the query vanishes.) D1 and D2 can be viewed as a decomposition of the distance-dependent decay term $D^{i-j}$ into $D^i$ and $D^{-j}$, but this factorization makes them numerically unstable.

https://github.com/YuchuanTian/DiJiang/blob/8a9d8da0c26a28552e7e8efb34a362d8139c3b2a/modeling/pythia-2.8B-dijiang/modeling_gpt_neox_dijiang.py#L154-L162
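A minimal pure-Python sketch of the failure mode described above, assuming an illustrative decay value `d = 0.95` (not DiJiang's actual initialization). Computing $D^{i-j}$ as $D^i \cdot D^{-j}$ underflows one factor to 0 and overflows the other to inf, so their product is `0 * inf = nan`, while computing it from the relative distance stays finite. All helper names here are hypothetical.

```python
import math

# Largest finite IEEE-754 single-precision (float32) value.
FLOAT32_MAX = 3.4028234663852886e38


def safe_pow(base: float, exponent: float) -> float:
    """Compute base**exponent in float64, returning inf instead of raising on overflow."""
    try:
        return base ** exponent
    except OverflowError:
        return math.inf


def factorized_decay(d: float, i: int, j: int) -> float:
    """D^(i-j) computed as D^i * D^(-j), mirroring the D1/D2 split (hypothetical helper)."""
    d1 = safe_pow(d, i)   # shrinks toward 0 as the query position i grows
    d2 = safe_pow(d, -j)  # grows toward inf as the key position j grows
    return d1 * d2        # 0.0 * inf -> nan once both factors saturate


def direct_decay(d: float, i: int, j: int) -> float:
    """D^(i-j) from the relative distance (for causal j <= i), like RetNet's gamma^(i-j)."""
    return d ** (i - j)


def fp32_overflow_position(d: float) -> int:
    """Smallest position j at which d**(-j) exceeds the float32 range (assumes 0 < d < 1)."""
    return math.floor(math.log(FLOAT32_MAX) / -math.log(d)) + 1


if __name__ == "__main__":
    d = 0.95  # illustrative value only
    print(factorized_decay(d, 20000, 20000))  # nan
    print(direct_decay(d, 20000, 20000))      # 1.0
    print(fp32_overflow_position(d))          # 1730
```

In float64 the split only breaks at very long positions, but in float32 training the $D^{-j}$ factor leaves the representable range much sooner: for `d = 0.95`, already at position 1730.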

As an alternative, I am considering incorporating DiJiang's kernel into a RetNet implementation, which computes $\gamma^{i-j}$ directly and is therefore numerically stable. That is, FKA would change as follows:

$$ \phi_{DCF}(x) = e^{TCx^T} $$

$$ \mathrm{FKA}(q_i, k_j, v_j) = \gamma^{i-j} \, \phi_{DCF}(q_i) \, \phi_{DCF}(k_j)^T v_j . $$

With this change, the vector $D$ is reduced to a scalar $\gamma$ and is no longer learnable, but the DCT kernel remains. Do you think DiJiang's strengths (i.e., linearizing a pre-trained Transformer with less training data while preserving its performance) would be maintained under this change? I would appreciate your opinion.
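A pure-Python sketch of the proposed FKA, under several assumptions not stated in the thread: $C$ is the orthonormal DCT-II matrix, $T$ is a fixed diagonal matrix (`t` below, with arbitrary values), $\phi_{DCF}(q_i)\,\phi_{DCF}(k_j)^T$ is read as the inner product of the two feature vectors, and no normalizing denominator is applied, matching the formula above. Because $\gamma \le 1$ only ever multiplies the running state, the recurrent form never divides by a decay factor and so avoids the $D^{-j}$ blow-up. Function names are hypothetical.

```python
import math
import random


def dct_matrix(n: int) -> list:
    """Orthonormal DCT-II matrix C (n x n); row k holds the k-th cosine basis vector."""
    rows = []
    for k in range(n):
        scale = math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n)
        rows.append([scale * math.cos(math.pi * (2 * m + 1) * k / (2 * n)) for m in range(n)])
    return rows


def phi_dcf(x: list, c: list, t: list) -> list:
    """Feature map exp(T C x^T), assuming T = diag(t) (hypothetical choice of T)."""
    cx = [sum(row[m] * x[m] for m in range(len(x))) for row in c]
    return [math.exp(tk * v) for tk, v in zip(t, cx)]


def fka_parallel(qs, ks, vs, gamma, c, t):
    """out_i = sum_{j<=i} gamma^(i-j) <phi(q_i), phi(k_j)> v_j in O(n^2), no normalizer."""
    fq = [phi_dcf(q, c, t) for q in qs]
    fk = [phi_dcf(k, c, t) for k in ks]
    outs = []
    for i in range(len(qs)):
        acc = [0.0] * len(vs[0])
        for j in range(i + 1):
            w = gamma ** (i - j) * sum(a * b for a, b in zip(fq[i], fk[j]))
            acc = [a + w * v for a, v in zip(acc, vs[j])]
        outs.append(acc)
    return outs


def fka_recurrent(qs, ks, vs, gamma, c, t):
    """Same outputs in O(n) via the state S_i = gamma * S_{i-1} + phi(k_i)^T v_i."""
    n_feat, d_v = len(c), len(vs[0])
    state = [[0.0] * d_v for _ in range(n_feat)]
    outs = []
    for q, k, v in zip(qs, ks, vs):
        fk = phi_dcf(k, c, t)
        # Decay the old state by gamma (never divides), then add the new key-value outer product.
        state = [[gamma * state[a][b] + fk[a] * v[b] for b in range(d_v)] for a in range(n_feat)]
        fq = phi_dcf(q, c, t)
        outs.append([sum(fq[a] * state[a][b] for a in range(n_feat)) for b in range(d_v)])
    return outs


def random_rows(n: int, d: int) -> list:
    """n random row vectors of size d with entries in [-1, 1] (toy inputs)."""
    return [[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)]


if __name__ == "__main__":
    random.seed(0)
    dim, d_v, length = 4, 2, 6000  # a length in the range where the issue reports NaNs
    c, t = dct_matrix(dim), [0.5] * dim  # t is an arbitrary illustrative scale
    qs, ks, vs = random_rows(length, dim), random_rows(length, dim), random_rows(length, d_v)
    out = fka_recurrent(qs, ks, vs, 0.97, c, t)
    print(all(math.isfinite(x) for row in out for x in row))  # True
```

`fka_recurrent` is the linear-time form one would actually run; `fka_parallel` is included only to check that both compute the same quantity.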

HantingChen commented 2 months ago

Thanks for your insightful question.

In our paper we compared against fine-tuning with RetNet, and both its perplexity (PPL) and its final evaluation metrics were slightly worse than our model's. However, we have not analyzed the individual effects of the DCT transformation and the learnable D in detail.

As a trade-off between the two, I suggest a variant: implement a RetNet-style decay but make $\gamma$ learnable. This could avoid the overflow issue while preserving as much accuracy as possible.
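One possible way to make $\gamma$ learnable while keeping it inside $(0, 1]$, sketched in pure Python: parameterize it as $\gamma = \sigma(\theta)$ and evaluate $\gamma^{i-j} = \exp((i-j)\log\sigma(\theta))$ in log space, so long distances underflow harmlessly to 0 instead of overflowing. The sigmoid parameterization and the helper names are assumptions for illustration, not something specified in the thread. In an autograd framework the gradient would come for free; it is written out here only to check the math.

```python
import math


def logit(p: float) -> float:
    """Inverse sigmoid, used to initialize theta from a chosen starting gamma."""
    return math.log(p / (1.0 - p))


def log_sigmoid(theta: float) -> float:
    """Numerically stable log(sigmoid(theta)); always <= 0, so gamma stays in (0, 1]."""
    if theta >= 0:
        return -math.log1p(math.exp(-theta))
    return theta - math.log1p(math.exp(theta))


def learnable_decay(theta: float, i: int, j: int) -> float:
    """gamma^(i-j) with gamma = sigmoid(theta), evaluated in log space; 0 for j > i (causal)."""
    if j > i:
        return 0.0
    return math.exp((i - j) * log_sigmoid(theta))  # can underflow to 0.0, never overflows


def decay_grad(theta: float, i: int, j: int) -> float:
    """d/dtheta of learnable_decay: (i - j) * (1 - sigmoid(theta)) * gamma^(i-j)."""
    if j > i:
        return 0.0
    sig = math.exp(log_sigmoid(theta))
    return (i - j) * (1.0 - sig) * learnable_decay(theta, i, j)


if __name__ == "__main__":
    theta = logit(0.99)                        # start from a gamma close to 1
    print(learnable_decay(theta, 100_000, 0))  # 0.0 (harmless underflow, no inf/nan)
```

The design choice here is to learn $\theta$ rather than $\gamma$ itself, so no optimizer step can push $\gamma$ above 1 and reintroduce growth with distance.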

yuji96 commented 2 months ago

Thank you for your prompt reply and the suggestion for improvement.

I'm interested in the respective effects of the DCT and of D's learnability. (Intuitively, the DCT seems much more important.)

If D stays close to its initial value after training, we can assume that it acts as a fixed positional bias throughout training. If so, that would be evidence for replacing it with another positional-bias method, such as the fixed $\gamma^{i-j}$ decay proposed in my first comment.
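A small sketch of how "close to its initial value" could be quantified, assuming each layer's D is a vector of values in $(0, 1)$ (as implied by the overflow analysis above). Decay values near 1 sit close together in absolute terms, so a raw L2 distance can look tiny even when the effective attention range changes a lot; comparing per-entry half-lives ($\ln 0.5 / \ln d$ tokens) is more sensitive. The tolerance for calling D "unchanged" is a judgment call, and the functions and toy values are hypothetical.

```python
import math


def relative_drift(d_init: list, d_final: list) -> float:
    """||D_final - D_init||_2 / ||D_init||_2 for one layer's decay vector."""
    if len(d_init) != len(d_final):
        raise ValueError("D vectors must have the same length")
    diff = math.sqrt(sum((b - a) ** 2 for a, b in zip(d_init, d_final)))
    return diff / math.sqrt(sum(a * a for a in d_init))


def half_life(d: float) -> float:
    """Distance in tokens after which d**n drops to 0.5 (assumes 0 < d < 1)."""
    return math.log(0.5) / math.log(d)


def max_half_life_change(d_init: list, d_final: list) -> float:
    """Largest relative change in half-life over the entries of one decay vector."""
    if len(d_init) != len(d_final):
        raise ValueError("D vectors must have the same length")
    return max(abs(half_life(b) - half_life(a)) / half_life(a) for a, b in zip(d_init, d_final))


def layers_near_init(init: dict, final: dict, tol: float = 0.05) -> list:
    """Names of layers whose D half-lives all moved by less than `tol` (an arbitrary threshold)."""
    return [name for name in init if max_half_life_change(init[name], final[name]) < tol]


if __name__ == "__main__":
    # Toy values, not measurements from DiJiang.
    init = {"layer0": [0.999, 0.99], "layer1": [0.999, 0.99]}
    final = {"layer0": [0.99901, 0.99], "layer1": [0.998, 0.99]}
    print(relative_drift(init["layer1"], final["layer1"]))        # small in L2 terms
    print(max_half_life_change(init["layer1"], final["layer1"]))  # about 0.5
    print(layers_near_init(init, final))                          # ['layer0']
```

For example, moving one entry from 0.999 to 0.998 changes its value by only about 0.1% but roughly halves its half-life (from about 693 to 346 tokens).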

Initial values of the Ds:

[Two plots of the initial D values; images not preserved]

Final values of the Ds: 🤔