test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States
MIT License

How is training stabilized? #21

Closed by Cascol-Chen 2 months ago

Cascol-Chen commented 2 months ago

Thanks for sharing such great work. However, since $\theta_K$ and $\theta_V$ share the same dimension, and $f(\cdot;\cdot)$ is simply a linear function, I have the following questions (a small sketch of how I read Eqn. (4) is included after them):

  1. How is training stabilized? Would optimization lead to a collapse where $\theta_K=\theta_V$ and $W$ captures nothing?
  2. What does $W$ capture? Since Eqn. (4) also relies on $\theta_K$ for reconstruction, how can $W$ be analyzed independently?
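
For concreteness, here is a minimal sketch of how I read Eqn. (4); the variable names and shapes are my own, not from the repo:

```python
import torch

d = 8                                   # hidden dimension, arbitrary for the sketch
theta_K = torch.randn(d, d)             # training-view projection
theta_V = torch.randn(d, d)             # label-view projection
W = torch.randn(d, d)                   # inner weights of f (TTT-Linear case)
x_t = torch.randn(d)                    # current token

train_view = theta_K @ x_t              # training view of x_t
label_view = theta_V @ x_t              # label view of x_t
loss = ((W @ train_view - label_view) ** 2).sum()   # Eqn. (4), as I read it

# My worry: if theta_K == theta_V, then W = identity already drives this loss to zero.
```
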
karan-dalal commented 2 months ago

There are a few reasons why collapse doesn't occur:

W denotes the weights of f (a single matrix for TTT-Linear, two matrices for TTT-MLP).
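
To make the two cases concrete, roughly (a sketch, not the exact code in this repo):

```python
import torch
import torch.nn.functional as F

def f_ttt_linear(x, W):
    """TTT-Linear: the hidden state W is a single matrix."""
    return x @ W

def f_ttt_mlp(x, W1, W2):
    """TTT-MLP: the hidden state is two matrices, with a nonlinearity (GELU here) in between."""
    return F.gelu(x @ W1) @ W2
```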

Cascol-Chen commented 2 months ago

Thanks for the quick response. I notice that the non-linearity is LayerNorm in the implementation. However, the reconstruction_target computed here does not apply LayerNorm to XK. Could you provide more explanation?

karan-dalal commented 2 months ago

The LayerNorm is not applied to the reconstruction target; it is part of our choice of f.

We add a residual connection to f for training stability (Sec. 2.7), so $f(x) = x + \mathrm{LN}(f_{\text{res}}(x))$. Since we use MSE loss, we can pre-compute the reconstruction target as the difference between the label view and the training view that enters the residual.
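
Concretely, the idea is (a rough sketch with illustrative names, not the exact repo code):

```python
import torch
import torch.nn.functional as F

d = 64
XK = torch.randn(16, d)                        # training view (theta_K x)
XV = torch.randn(16, d)                        # label view (theta_V x)
W = torch.randn(d, d, requires_grad=True)      # inner weights of f_res (TTT-Linear case)
ln = torch.nn.LayerNorm(d)

# With f(x) = x + LN(f_res(x)) and MSE,
#   || f(XK) - XV ||^2 == || LN(f_res(XK)) - (XV - XK) ||^2,
# so the target for the LN branch can be formed up front:
reconstruction_target = XV - XK
inner_loss = F.mse_loss(ln(XK @ W), reconstruction_target)
```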

The LN is applied here, during the forward pass of the test view.
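
In the same sketch notation, the output rule looks like this (TTT-Linear case, illustrative names):

```python
import torch

d = 64
W = torch.randn(d, d)              # inner weights after the test-time update
ln = torch.nn.LayerNorm(d)
XQ = torch.randn(16, d)            # test view (theta_Q x)

output = XQ + ln(XQ @ W)           # f(x) = x + LN(f_res(x)), applied to the test view
```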

Cascol-Chen commented 2 months ago

Thanks again for your quick and thoughtful response. I believe this work will broaden our understanding of TTT/TTA, and I am hopeful that TTT/TTA will become standard practice in deep learning.