BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it combines the best of RNNs and transformers: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0

Questions about formula 14 and code implementation in the paper #167

Closed FengrunZhang closed 11 months ago

FengrunZhang commented 12 months ago

Hello author, amazing work! I have some questions from reading the code and the paper, about the function `att_seq`, which corresponds to formula 14 in the paper:

```python
for t in range(T):
    kk = k[t]
    vv = v[t]
    ww = t_first + kk
    p = torch.maximum(pp, ww)
    e1 = torch.exp(pp - p)
    e2 = torch.exp(ww - p)
    sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
    ww = t_decay + pp
    p = torch.maximum(ww, kk)
    e1 = torch.exp(ww - p)
    e2 = torch.exp(kk - p)
    aa = e1 * aa + e2 * vv
    bb = e1 * bb + e2
    pp = p
```

What I don't understand is what `torch.maximum` does here. In the second `torch.maximum`, `ww` represents the sum of the previous k_i and w, and `kk` represents the k_i at the current step. What is the effect of fusing the two through this function? I would be very grateful if you could answer.

BlinkDL commented 11 months ago

read https://ben.bolte.cc/rwkv-model
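For readers landing here: the `torch.maximum` calls in the loop above are a log-sum-exp numerical-stability trick, not part of the model's math. The WKV numerator and denominator are sums of `exp(k_i) * v_i` and `exp(k_i)`, which overflow float32 for large keys; the code therefore keeps the accumulators scaled by `exp(-pp)` where `pp` tracks the running maximum exponent, so every `exp()` argument stays non-positive. Below is a standalone sketch of that trick (illustrative names, not the repository's code) under that interpretation:

```python
import torch

def stable_add(aa, bb, pp, kk, vv):
    """Fold the term exp(kk) * vv into accumulators kept in scaled form.

    Invariant: true numerator   = exp(pp) * aa
               true denominator = exp(pp) * bb
    With p = max(pp, kk), both exp() arguments below are <= 0,
    so nothing overflows even when kk is very large.
    """
    p = torch.maximum(pp, kk)
    e1 = torch.exp(pp - p)  # rescales the old accumulators to the new base p
    e2 = torch.exp(kk - p)  # the new term, expressed relative to the same base
    return e1 * aa + e2 * vv, e1 * bb + e2, p

# Keys large enough that exp() overflows float32 if taken directly:
ks = torch.tensor([100.0, 101.0, 102.0])
vs = torch.tensor([1.0, 2.0, 3.0])

# Naive weighted average: exp(100.) is inf in float32, and inf / inf = nan.
naive = (torch.exp(ks) * vs).sum() / torch.exp(ks).sum()

# Stable accumulation: pp starts at a huge negative value (acts as -inf,
# so the first exp(pp - p) underflows harmlessly to 0).
aa = torch.tensor(0.0)
bb = torch.tensor(0.0)
pp = torch.tensor(-1e38)
for kk, vv in zip(ks, vs):
    aa, bb, pp = stable_add(aa, bb, pp, kk, vv)

print(naive)    # nan
print(aa / bb)  # finite weighted average, ~2.575
```

So the second `torch.maximum(ww, kk)` is simply choosing the new common exponent base before merging the decayed history (`ww = t_decay + pp`) with the current key `kk`; it has no effect on the value being computed, only on how it is represented.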