Closed Cascol-Chen closed 2 months ago
There are a few reasons why collapse doesn't occur:
W holds the weights of f (one matrix for TTT-Linear, two matrices for TTT-MLP)
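As a rough sketch of what W looks like in each variant (the head dimension, the 4x hidden width, and the GELU activation for TTT-MLP are assumptions for illustration, not taken from this thread):

```python
import numpy as np

d = 8  # hypothetical head dimension

# TTT-Linear: W is a single d x d matrix.
W_linear = np.zeros((d, d))

def f_linear(x, W):
    # inner model output before any residual/LN wrapper
    return x @ W

# TTT-MLP: W consists of two matrices, i.e. a 2-layer MLP
# (4x hidden width and GELU are assumptions here).
W1 = np.zeros((d, 4 * d))
W2 = np.zeros((4 * d, d))

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def f_mlp(x, W1, W2):
    return gelu(x @ W1) @ W2
```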
Thanks for the quick response. I notice that the non-linearity is LayerNorm
in the implementation. However, the reconstruction_target computed here does not apply LayerNorm
to XK. Could you provide more explanation?
The LayerNorm is not on the reconstruction target; it's included in our choice of f.
We add a residual connection to f for training stability (Sec. 2.7), so $f(x) = x + \mathrm{LN}(f_{\text{res}}(x))$. Since we use MSE loss, we can pre-compute the reconstruction target as the difference between the label view and the residual.
The LN is applied here during the forward of the test view.
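To spell out why the residual can be folded into the target: with MSE loss, $\|f(x_K) - x_V\|^2 = \|\mathrm{LN}(f_{\text{res}}(x_K)) - (x_V - x_K)\|^2$, so $x_V - x_K$ can be pre-computed once. A minimal numpy sketch checking this equivalence (a TTT-Linear inner model and a plain LayerNorm without learnable scale/bias are simplifying assumptions here):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # LayerNorm without learnable scale/bias, for simplicity
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

d = 8
XK = np.random.randn(d)        # training (key) view
XV = np.random.randn(d)        # label (value) view
W = np.random.randn(d, d) * 0.02

def f(x, W):
    # f(x) = x + LN(f_res(x)), with f_res linear here
    return x + layer_norm(x @ W)

# direct MSE loss against the label view
loss_direct = np.sum((f(XK, W) - XV) ** 2)

# equivalent: fold the residual into a pre-computed target
reconstruction_target = XV - XK
loss_precomputed = np.sum((layer_norm(XK @ W) - reconstruction_target) ** 2)

assert np.allclose(loss_direct, loss_precomputed)
```

Both losses are identical term by term, since $f(x_K) - x_V = \mathrm{LN}(f_{\text{res}}(x_K)) - (x_V - x_K)$.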
Thanks again for your quick and thoughtful response. I believe this work will broaden our understanding of TTT/TTA, and I am hopeful that TTT/TTA will become a standard practice in deep learning.
Thanks for sharing such great work. However, since $\theta_K$ and $\theta_V$ share the same dimension, and $f(\cdot;\cdot)$ is simply a linear function, I have the following questions: