test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States

Why is the sequence length increasing as auto-regressive decoding progresses? #17

Closed TheTinyTeddy closed 2 months ago

TheTinyTeddy commented 2 months ago

Hi,

I was wondering why, in the code, the sequence length increases as more tokens are predicted, since this is similar to how a Transformer decoder does inference.

My understanding of TTT is that it is an RNN with a per-token weight-update mechanism, so the sequence length at each decoding step should always be 1. Could you tell me if I am missing something?

xvjiarui commented 2 months ago

Hi @TheTinyTeddy, thanks for your interest in our work. The behavior you describe requires setting up a cache during decoding. To enable this, set `configuration = TTTConfig(use_cache=True)`. Please let me know if this resolves your issue.
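
For reference, a minimal decoding sketch, assuming the `TTTConfig` / `TTTForCausalLM` interface from the repo's `ttt.py` and the Llama tokenizer used in the README (the prompt text is just illustrative):

```python
import torch
from transformers import AutoTokenizer
from ttt import TTTConfig, TTTForCausalLM

# Enable the fixed-size TTT cache for decoding.
configuration = TTTConfig(use_cache=True)
model = TTTForCausalLM(configuration)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
input_ids = tokenizer("Greeting from TTT!", return_tensors="pt").input_ids

with torch.no_grad():
    # With use_cache=True, each decoding step only needs to process the
    # newest token; without it, the whole prefix is re-encoded every step,
    # which is why the observed sequence length grows during generation.
    out = model.generate(input_ids=input_ids, max_length=50)

print(tokenizer.decode(out[0]))
```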

TheTinyTeddy commented 2 months ago

Thank you for the reply!

So is there a notion in TTT similar to the "KV cache" in a Transformer decoder?

xvjiarui commented 2 months ago

Yes, but the cache size doesn't increase with the sequence length. It's a fixed size, like any other RNN's state.
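
To illustrate the difference, here is a schematic sketch (not the repo's actual cache classes): a Transformer KV cache appends one key/value pair per decoded token, while a TTT-style cache holds only the current fast weights `W`, which keep the same shape at every step.

```python
import torch

# Transformer-style KV cache: one new (k, v) entry per decoded token,
# so memory grows linearly with the generated length.
kv_cache = {"k": [], "v": []}

def kv_cache_step(k_t: torch.Tensor, v_t: torch.Tensor) -> None:
    kv_cache["k"].append(k_t)
    kv_cache["v"].append(v_t)

# TTT-style state: the hidden state is a weight matrix W, updated in place
# by a gradient step on a per-token self-supervised loss, so the cache size
# stays constant regardless of how many tokens have been decoded.
dim = 64
W = torch.zeros(dim, dim)

def ttt_step(x_t: torch.Tensor, lr: float = 0.1) -> None:
    # Gradient of the reconstruction loss ||W x_t - x_t||^2 w.r.t. W
    # (a schematic stand-in for the paper's inner-loop update).
    grad = 2.0 * torch.outer(W @ x_t - x_t, x_t)
    W.sub_(lr * grad)  # same shape (dim, dim) at every step
```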