test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States
MIT License

The understanding of W #24

Closed by Z-Z188 1 month ago

Z-Z188 commented 1 month ago

[screenshot of a formula from the paper] Hello author, thank you very much for your work. What does the W in this formula refer to? Is it W_0 or W_b? I don't understand. Looking forward to your reply, thank you!

LeoXinhaoLee commented 1 month ago

Hi, thank you for your question.

The section you are looking at discusses the update rule for a hidden state which is a K-layer MLP. Therefore, k here refers to the k-th layer in this MLP, and W^{k} means the weight matrix of that layer.
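A minimal sketch of what this means in code (not the repo's actual module; the layer count, width, and the GELU nonlinearity are assumptions made purely for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical hidden state: a K-layer MLP. W[k-1] plays the role of W^k,
# the weight matrix of the k-th layer discussed above.
K, d = 2, 16                                     # made-up layer count and width
W = [torch.randn(d, d) * 0.02 for _ in range(K)]

def hidden_state_forward(x, W):
    """Forward one token embedding through the K layers, producing Z^1, ..., Z^K."""
    z = x
    for k, W_k in enumerate(W, start=1):
        z = W_k @ z                              # Z^k: output of the k-th layer
        if k < len(W):
            z = F.gelu(z)                        # assumed nonlinearity between layers
    return z                                     # Z^K: the hidden state's final output

z = hidden_state_forward(torch.randn(d), W)
```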

Z-Z188 commented 1 month ago

Thanks for your reply. The weight matrix W is updated during the forward pass: from W_0, W_1, W_2, ... to W_b. So which W does this formula refer to: W_0, W_1, or W_b? Looking forward to your reply. Thanks!

LeoXinhaoLee commented 1 month ago

Say \theta_0 refers to all the parameters of the hidden state (a neural network in our case). Then \theta_1, ..., \theta_b are the full sets of updated parameters after each TTT mini-batch, which I think is what you are referring to.

On the other hand, W^k means the weight of the k-th layer inside this hidden state (a neural network), so \theta_t = \{W_t^1, \ldots, W_t^K\}.
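To make the two indices concrete (an illustrative example, not a quote from the paper): with K = 2 layers, the hidden state starts as \theta_0 = \{W_0^1, W_0^2\}, and after two TTT updates it is \theta_2 = \{W_2^1, W_2^2\}; the subscript counts updates, while the superscript picks out a layer inside the MLP.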

Z-Z188 commented 1 month ago

[screenshot of the formula, with W highlighted in a red box] Thanks for your reply. So when we calculate the output Z, which W do we use (the W in the red box): W_0, W_1, or W_b?

LeoXinhaoLee commented 1 month ago

Say at time step b (the b-th TTT mini-batch), the hidden state value is \theta_b = {W_b^1, ..., W_b^K}.

The final output Z_b at time step b is calculated by forwarding through the K layers in \theta_b sequentially, i.e., computing Z_b^1 through Z_b^K.

Your confusion may come from the fact that the formula is not specific to a time step; it applies to any time step.

Z-Z188 commented 1 month ago

the formula is not specific to a time step; it applies to any time step.

Thanks for your reply! The sentence "the formula is not specific to a time step; it applies to any time step" cleared things up for me. Thanks for your patience.

LeoXinhaoLee commented 1 month ago

You are very welcome!

Z-Z188 commented 1 month ago

[screenshot of another formula from the paper] But another question: why does the W here refer to W_0? Is it specific to a time step? Looking forward to your reply. Thanks!

LeoXinhaoLee commented 1 month ago

Here W_0 stands for the initial value of the hidden state at an arbitrary time step, and 1, ..., b refer to the 1st through b-th tokens inside an arbitrary TTT mini-batch. W_1, ..., W_b stand for the updated hidden state values after each of those tokens.

Therefore, it's not specific to a time step.
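In case it helps, here is a tiny numerical sketch of one such mini-batch. It uses a linear hidden state and a squared-error inner loss as stand-ins, dropping details like LayerNorm and the learnable inner-loop learning rate, so it is not the repo's actual code:

```python
import torch

torch.manual_seed(0)
d, b, eta = 8, 4, 0.1              # made-up sizes: feature dim, TTT mini-batch size, inner LR
W0 = torch.randn(d, d) * 0.1       # W_0: hidden state value at the start of this mini-batch
X = torch.randn(d, b)              # the b tokens x_1, ..., x_b of the mini-batch (columns)
Y = torch.randn(d, b)              # stand-in targets for the self-supervised inner loss

# Gradients are all taken at W_0 (the parallelizable mini-batch form discussed here):
#   W_t = W_0 - eta * sum_{s <= t} grad_{W_0} l(W_0; x_s)
W = [W0]
for t in range(b):
    g_t = W0 @ X[:, t] - Y[:, t]                       # dl/dz at z = W_0 x_t, for l = 0.5 * ||W_0 x_t - y_t||^2
    W.append(W[-1] - eta * torch.outer(g_t, X[:, t]))  # W_t = W_{t-1} - eta * grad_{W_0} l(W_0; x_t)
    z_t = W[-1] @ X[:, t]                              # output for token t uses the updated W_t

# W[0] is W_0; W[1], ..., W[b] are W_1, ..., W_b.
# W[b] then serves as the initial hidden state of the next TTT mini-batch.
```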

Z-Z188 commented 1 month ago

Thank you for your reply. May I ask: in the calculation, why do we convert the gradient of the loss with respect to W_0 into the gradient with respect to Z?

LeoXinhaoLee commented 1 month ago

Sorry I didn't get your question, could you point me to the specific formula you are referring to?

Z-Z188 commented 1 month ago

Sorry I didn't get your question, could you point me to the specific formula you are referring to?

[screenshot of the update formula] When calculating the updated W, we use the gradient of the loss with respect to W_0, but why is the gradient of the loss with respect to Z used in the subsequent formulas? Why is such a conversion necessary? [screenshot of the subsequent formulas]

LeoXinhaoLee commented 1 month ago

$\nabla_{Z^k} l$ comes from the first factor of $\nabla_{W_0^k} l = \nabla_{Z^k} l \cdot \left(\hat{X}^k\right)^T$. The point of the dual form is that, by applying the associativity of matrix multiplication, we can directly calculate the output of the second forward pass $\bar{Z}$ without explicitly materializing $\nabla_{W_0^k} l$.

Please refer to Appendix A.3 for more details.
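For intuition, here is a small numerical check of that identity. It again uses a single linear hidden state and a squared-error inner loss as stand-ins (dropping the layer index k), so it only illustrates the associativity trick, not the repo's fused implementation:

```python
import torch

torch.manual_seed(0)
d, b, eta = 8, 4, 0.1
W0 = torch.randn(d, d) * 0.1       # W_0 at the start of one TTT mini-batch
X = torch.randn(d, b)              # tokens x_1, ..., x_b (columns)
Y = torch.randn(d, b)              # stand-in inner-loss targets

G = W0 @ X - Y                     # grad_Z l for each token: the first factor above

# Primal form: materialize grad_{W_0} l = G[:, s] x_s^T for every token, update, then forward.
Z_primal = torch.stack([
    (W0 - eta * sum(torch.outer(G[:, s], X[:, s]) for s in range(t + 1))) @ X[:, t]
    for t in range(b)
], dim=1)

# Dual form: since grad_{W_0} l = grad_Z l * x^T, associativity gives
#   z_t = W_0 x_t - eta * sum_{s <= t} grad_{Z_s} l * (x_s^T x_t),
# i.e. Z = W_0 X - eta * G @ triu(X^T X), and no d-by-d gradient is ever materialized.
Z_dual = W0 @ X - eta * G @ torch.triu(X.T @ X)

print(torch.allclose(Z_primal, Z_dual, atol=1e-5))  # True: the two forms agree
```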