test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States

RuntimeError: Trying to backward through the graph a second time with custom `ttt_layer` in self-attention #15

Closed · maojiaqi111 closed this issue 3 months ago

maojiaqi111 commented 3 months ago

Hello,

I have replaced a model's self-attention with TTTLinear using the code provided. Calling loss.backward() on the first batch works fine, but on the second batch loss.backward() raises the following error:

File "train_engine.py", line 43, in train_one_epoch
scaler.scale(losses["cost"]).backward()
File "train.py", line 194, in main
train_one_epoch(
File "train.py", line 255, in
main()
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

The error seems related to intermediate values being freed inside the ttt_layer, or to how its gradients are computed. Since I'm not very familiar with the internal gradient computation of the ttt_layer, I cannot pinpoint the offending code.

Do you have any insights or suggestions on how to resolve this issue?

Thank you!
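
For reference, here is a minimal, self-contained sketch (all names are illustrative, not taken from this repo) of how this kind of error typically arises with a stateful layer: if the layer caches a tensor that is still attached to the previous batch's graph, the second backward() walks back into the first batch's already-freed graph.

```python
import torch
import torch.nn as nn

class StatefulLayer(nn.Module):
    """Toy layer that carries a cached tensor across forward calls."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.state = None  # persists between batches

    def forward(self, x):
        if self.state is None:
            self.state = torch.zeros_like(x)
        # BUG: self.state is still attached to the previous batch's graph.
        # A typical fix is to detach the cached state before reusing it,
        # e.g. self.state = self.state.detach() + self.proj(x)
        self.state = self.state + self.proj(x)
        return self.state

layer = StatefulLayer(8)
opt = torch.optim.SGD(layer.parameters(), lr=1e-3)

for step in range(2):
    x = torch.randn(4, 8)
    loss = layer(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()  # fails on the second batch with the error above
    opt.step()
```

If the TTT layer (or my wrapper around it) keeps such a cached state across batches, detaching it between batches, rather than setting retain_graph=True, would presumably be the right fix.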



LuoyaoChen commented 3 months ago

Hi, did you remove the transformer's K, Q, V entirely and replace them with the TTT-defined K, Q, V?

xvjiarui commented 3 months ago

Hi @maojiaqi111

Could you provide some code to reproduce your error? Alternatively, have you tried training the provided TTT model directly, instead of modifying self-attention?

LuoyaoChen commented 3 months ago

Hi @xvjiarui, thanks, and +1 to that question. I also wonder whether it is feasible to load the K, Q, V weights from a pretrained transformer (a standard transformer's self-attention, not the K, Q, V in ttt.linear) directly and easily, or does one have to map each key to ttt.linear manually? If the latter, I guess the mapping would require very careful effort and would be susceptible to errors?

Thank you!

karan-dalal commented 3 months ago

@LuoyaoChen Hi. The Q, K, V projections in a transformer are analogous to those in TTT. You should be able to load them directly; they have the same shapes. See page 7:

[Screenshot of page 7 of the paper, showing the correspondence between attention and TTT parameters.]
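
A minimal sketch of what that loading could look like, assuming the TTT layer exposes its projections as q_proj / k_proj / v_proj nn.Linear modules with the same shapes as the pretrained attention (the attribute names below are assumptions; check them against ttt.py and your checkpoint):

```python
import torch

def load_attention_qkv_into_ttt(attn_layer, ttt_layer):
    """Copy Q/K/V projection weights from a standard self-attention layer
    into a TTT layer, assuming both expose q_proj / k_proj / v_proj as
    nn.Linear modules with identical shapes (attribute names are assumptions)."""
    with torch.no_grad():
        for name in ("q_proj", "k_proj", "v_proj"):
            src = getattr(attn_layer, name).weight
            dst = getattr(ttt_layer, name).weight
            assert src.shape == dst.shape, f"shape mismatch for {name}"
            dst.copy_(src)

# Usage (illustrative): for each transformer block, pair its original
# attention module with the TTT layer that replaces it and copy the weights:
# load_attention_qkv_into_ttt(block.self_attn, block.ttt)
```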