test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States
MIT License

Replicating the experiments #16

Closed: LuoyaoChen closed this issue 3 months ago

LuoyaoChen commented 3 months ago

Hi,

Thanks for the great work! I have two questions.

So, to replicate the work, one needs to run two scripts:

a) Step 1: training on the training data to learn K, Q, V.
b) Step 2: test-time learning, which consists of loading the trained K, Q, V weights, initializing TTT, and updating the TTT-Linear/MLP inner network?

Is this understanding correct?

The second question is: In step 1, does the code support loading pre-trained transformer weights to save time?

Thank you so much!

karan-dalal commented 3 months ago

Hi, thank you for your interest in our work!

Your understanding is correct, to a certain extent. During pre-training, we train KQV, the initialization of the TTT inner network (Linear / MLP), and the remaining network components (Pre-LN, FFN, etc.). At test time, we load the pre-trained weights and update the TTT inner network on the sequence.
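To make the split between pre-trained weights and test-time updates concrete, here is a minimal, hypothetical pure-Python sketch of one TTT-Linear layer's forward pass. Assumptions (not taken from this repository's code): the inner model is f(k; W) = W k, the self-supervised loss is the reconstruction error ||W k_t - v_t||^2 where k_t, v_t, q_t come from token x_t through the learned projections theta_K, theta_V, theta_Q, and W takes one plain gradient step of size eta per token. The names `ttt_linear_forward`, `matvec`, and `eta` are illustrative only. In this sketch, theta_K, theta_Q, theta_V and W0 stand for the weights learned in pre-training, while W is the hidden state updated on the sequence.

```python
# Hypothetical sketch of a TTT-Linear forward pass (illustrative, not the repo's API).
# Assumed: inner model f(k; W) = W k, loss ||W k_t - v_t||^2,
# one gradient step per token with a fixed inner learning rate eta.

def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def ttt_linear_forward(xs, theta_K, theta_Q, theta_V, W0, eta=0.1):
    """Run the inner loop over a sequence; return per-token outputs and final W.

    theta_K, theta_Q, theta_V and W0 play the role of pre-trained (outer-loop)
    weights. W is the hidden state, updated by gradient descent on each token.
    """
    W = [row[:] for row in W0]  # every sequence starts from the learned init W0
    outputs = []
    for x in xs:
        k = matvec(theta_K, x)  # training view of the token
        v = matvec(theta_V, x)  # reconstruction target
        q = matvec(theta_Q, x)  # test view used to produce the output
        # Residual of the reconstruction loss ||W k - v||^2
        err = [wk - vi for wk, vi in zip(matvec(W, k), v)]
        # dLoss/dW[i][j] = 2 * err[i] * k[j]; take one gradient step
        W = [[W[i][j] - eta * 2 * err[i] * k[j] for j in range(len(k))]
             for i in range(len(W))]
        # Output: the updated inner model applied to the query view
        outputs.append(matvec(W, q))
    return outputs, W


if __name__ == "__main__":
    I = [[1, 0], [0, 1]]
    Z = [[0, 0], [0, 0]]
    outs, W = ttt_linear_forward([[1, 0], [1, 0]], I, I, I, Z, eta=0.25)
    print(outs)  # [[0.5, 0.0], [0.75, 0.0]]
```

With identity projections, W0 = 0 and eta = 0.25, feeding the token [1, 0] twice yields outputs [0.5, 0.0] and then [0.75, 0.0]: the hidden state W moves toward reconstructing v from k as it sees more of the sequence, while W0 itself is left unchanged and can be reused for the next sequence.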

Our code should support loading trained weights; however, we haven't released any (see https://github.com/test-time-training/ttt-lm-jax/issues/1).

To replicate results from the paper, we suggest using our JAX codebase. It includes training code, as well as scripts that directly replicate experiments from our paper.

Please let me know if you have any other questions!

LJHzju commented 2 months ago

Hi, I would like to know whether, during training, the model also needs to use test-time learning to update the inner-loop parameters.