test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States
MIT License

How to use TTTLinear/TTTMLP like RNNs #9

Closed zx1292982431 closed 3 months ago

zx1292982431 commented 3 months ago

How can I use TTTLinear/TTTMLP to perform time series modeling on data of shape [B, T, C], just like an RNN? Could you give me an example?

LeoXinhaoLee commented 3 months ago

Hi, thanks for your question. TTT layers are essentially RNN blocks. So if you can use an RNN for your application, you should be able to use TTT layers as well.

zx1292982431 commented 3 months ago

Thanks for your reply! I initialized a TTTLinear and tried a forward pass like this:

from ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS, TTTLinear
import torch

configuration = TTTConfig()

model = TTTLinear(configuration, layer_idx=0)
X = torch.rand(1, 101, 2048)  # [B, T, C]
Y = model(X)

but I got this error:

Traceback (most recent call last):
  File "/home/lizixuan/workspace/projects/tmp_coding/TTTtest.py", line 8, in <module>
    Y = model(X)
  File "/home/lizixuan/miniconda3/envs/lzx/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lizixuan/miniconda3/envs/lzx/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/raid2/user_space/lizixuan/projects/tmp_coding/ttt.py", line 869, in forward
    cos, sin = self.rotary_emb(XV, position_ids % self.mini_batch_size)
TypeError: unsupported operand type(s) for %: 'NoneType' and 'int'

I noticed that position_ids is optional. Is it actually necessary? Please point out where I went wrong. Thanks again.

LeoXinhaoLee commented 3 months ago

Specific to this paper, we use RoPE positional embeddings to distinguish the positions within one TTT mini-batch (which, in this paper, contains only 16 positions), and this empirically improves performance a bit.

Therefore, the TTTLinear and TTTMLP modules in this code require position_ids, which is generated here: https://github.com/test-time-training/ttt-lm-pytorch/blob/4ddc7732cae4ac44884804ef9febf2a9a435c3f8/ttt.py#L1454
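
For instance, a minimal sketch of calling the layer directly with explicit position_ids (the argument names follow the forward signature shown in the traceback, and the sequence length is kept a multiple of the default mini_batch_size of 16; treat this as an illustration rather than code from the repo):

import torch
from ttt import TTTConfig, TTTLinear

configuration = TTTConfig()
layer = TTTLinear(configuration, layer_idx=0)

B, T = 1, 96  # T chosen divisible by the default mini_batch_size (16) to keep chunking simple
X = torch.rand(B, T, configuration.hidden_size)  # hidden states, [B, T, C]
position_ids = torch.arange(T).unsqueeze(0)      # [1, T]; RoPE indexes position_ids % mini_batch_size

Y = layer(X, position_ids=position_ids)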

Dushuai12138 commented 3 months ago

I tried to use TTTModel:

from transformers import AutoTokenizer
from ttt import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS, TTTModel
import torch

configuration = TTTConfig()
model4 = TTTModel(configuration).to("cuda")
print(model4)

batch_size, length, dim1 = 1, 20, 2048
x = torch.randint(0, 10, (batch_size, length, dim1)).to("cuda")

model4.train()
y = model4(x)

and got this error:

File /scratch/yyzha/dushuai/SMPre/jupyter/ttt.py:865, in TTTBase.forward(self, hidden_states, attention_mask, position_ids, cache_params)
    862 XQ, XK, XV = self.get_qkv_projections(hidden_states, cache_params=cache_params)
    864 # [B, L, C] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
--> 865 XQ = XQ.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    866 XK = XK.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    867 XV = XV.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)

RuntimeError: shape '[1, 20, 32, 64]' is invalid for input of size 83886080

LeoXinhaoLee commented 3 months ago

Hi, I think it's because our language model is defined to take in token indices as input, so the input shape should be [B, T], where each element is the id of a token. Please see https://github.com/test-time-training/ttt-lm-pytorch/blob/4ddc7732cae4ac44884804ef9febf2a9a435c3f8/ttt.py#L1446
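
For example, something along these lines (a sketch assuming TTTModel follows the usual Hugging Face convention of taking integer input_ids of shape [B, T], and that TTTConfig exposes vocab_size):

import torch
from ttt import TTTConfig, TTTModel

configuration = TTTConfig()
model = TTTModel(configuration).to("cuda")

batch_size, length = 1, 20
# Token ids of shape [B, T]; the embedding layer maps them to [B, T, hidden_size] internally.
input_ids = torch.randint(0, configuration.vocab_size, (batch_size, length)).to("cuda")

model.train()
outputs = model(input_ids)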

Dushuai12138 commented 3 months ago

Yeah, it works. Thanks a lot.