test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States
MIT License

How to load pre-trained weights from Hugging Face? #22

Closed: helldog-star closed this issue 2 months ago.

karan-dalal commented 2 months ago

We have released some public checkpoints for research use. Note that these checkpoints are trained on a Chinchilla-optimal token count and will not generate coherent text.
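
For context, "Chinchilla-optimal" refers to the compute-optimal scaling result of Hoffmann et al. (2022), often summarized as roughly 20 training tokens per model parameter. The sketch below applies that rule of thumb. The 20:1 ratio is an approximation, and the model sizes are illustrative. They are not the exact token counts used for these checkpoints.

```python
def chinchilla_token_budget(n_params: int, tokens_per_param: float = 20.0) -> int:
    """Approximate compute-optimal training tokens via the ~20 tokens/parameter rule of thumb."""
    return int(n_params * tokens_per_param)


# Illustrative model sizes only; not taken from the released checkpoints.
for n in (125_000_000, 350_000_000, 760_000_000, 1_300_000_000):
    print(f"{n / 1e6:>6.0f}M params -> ~{chinchilla_token_budget(n) / 1e9:.1f}B tokens")
```

Budgets of a few billion to a few tens of billions of tokens are small by the standards of chat-capable language models, which is consistent with the note that these checkpoints will not produce coherent text.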

You can convert the Hugging Face checkpoint from JAX format to PyTorch format using this script: https://gist.github.com/xvjiarui/835bab0a867cd1f4a20714a5f860abfd
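
The gist's exact parameter mapping is not reproduced here. As a rough sketch of what such a conversion typically involves, assuming Flax-style leaf names (kernel, scale, embedding), the helpers below flatten a nested parameter tree into dotted keys and convert dense kernels from Flax's (in_features, out_features) layout to the (out_features, in_features) layout that torch.nn.Linear expects. The function names are hypothetical, and mapping module names onto the PyTorch TTT model is left to the real script.

```python
import numpy as np


def flatten_params(tree, prefix=""):
    """Flatten a nested dict of arrays into {"a.b.c": array}."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_params(value, name))
        else:
            flat[name] = value
    return flat


def flax_to_torch_layout(tree):
    """Rename Flax-style leaves to PyTorch-style names and transpose dense kernels.

    Assumption: leaves follow Flax naming (kernel / bias / scale / embedding).
    Module path names are kept unchanged; real checkpoints may need extra renaming.
    """
    out = {}
    for name, value in flatten_params(tree).items():
        head, _, leaf = name.rpartition(".")
        arr = np.asarray(value)
        if leaf == "kernel" and arr.ndim == 2:
            # Flax Dense stores (in, out); torch.nn.Linear.weight is (out, in).
            leaf, arr = "weight", arr.T
        elif leaf in ("scale", "embedding"):
            # Norm scales and embedding tables keep their layout; only the name changes.
            leaf = "weight"
        new_name = f"{head}.{leaf}" if head else leaf
        out[new_name] = np.ascontiguousarray(arr)
    return out
```

The resulting arrays could then be wrapped with torch.from_numpy to build a state dict; whether the keys line up with the PyTorch model depends on the checkpoint's module naming.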

MeNicefellow commented 1 month ago

ModuleNotFoundError: No module named 'transformers.models.ttt'

My transformers installation is already up to date. Please test your code in a fresh environment before posting. It is such a simple principle.
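
One way to check what the active environment can actually import (useful both for the reporter and when reproducing in a fresh environment) is importlib.util.find_spec from the Python standard library. The helper name module_available is hypothetical. Based on the resolution later in this thread, transformers.models.ttt would be expected to report False even on an up-to-date transformers install.

```python
import importlib.util
import sys


def module_available(name: str) -> bool:
    """Return True if `name` can be imported, without executing that module itself.

    For dotted names, find_spec imports the parent packages first, so a missing
    parent raises ModuleNotFoundError; treat that as "not available".
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ModuleNotFoundError, ValueError):
        return False


if __name__ == "__main__":
    print(sys.executable)  # which interpreter (and therefore environment) is running
    for name in ("transformers", "transformers.models.ttt"):
        print(f"{name}: {module_available(name)}")
```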

juntaic7 commented 1 month ago

Hi, I'm wondering whether this problem has been solved. Or should we build transformers from source and add the package to it?

helldog-star commented 1 month ago

I think I've solved the problem. The TTT model is not included in the latest version of Transformers, which is why transformers.models.ttt cannot be found. Instead, import TTTForCausalLM and TTTConfig from ttt.py in this codebase. :)
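
When running from the repository root, a plain "from ttt import TTTForCausalLM, TTTConfig" is enough. From elsewhere, one option is to load ttt.py by file path with the standard library's importlib, as sketched below. The helper name load_module_from_path and the checkout path are hypothetical, and ttt.py is assumed to need torch and transformers installed in the same environment.

```python
import importlib.util
import sys
from pathlib import Path


def load_module_from_path(path, module_name="ttt"):
    """Import a single .py file (e.g. a repository's ttt.py) as a module, by file path."""
    path = Path(path)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load a module from {path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so code inside the module can find it in sys.modules
    # (dataclasses and pickling rely on this).
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module behind on failure.
        sys.modules.pop(module_name, None)
        raise
    return module
```

For example, ttt = load_module_from_path("ttt-lm-pytorch/ttt.py") would expose ttt.TTTConfig and ttt.TTTForCausalLM. If TTTForCausalLM is a transformers PreTrainedModel subclass (an assumption here), a converted checkpoint could then be loaded through its from_pretrained method.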