test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States
MIT License
1.01k stars, 56 forks

Accelerating Torch training for TTT #26

Closed: jongjyh closed this issue 1 month ago

jongjyh commented 1 month ago

Hi, authors!

Thank you for open-sourcing such a great project. I would like to know how to speed up training with TTT-torch.

Also, what might cause the speed difference between the PyTorch and JAX versions?

Looking forward to your reply. Thanks!

Best regards.

karan-dalal commented 1 month ago

JAX's just-in-time compilation is pretty good. We're working on PyTorch training speed and will release something soon.
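
For readers unfamiliar with what just-in-time compilation buys, here is a minimal sketch in JAX. It is only an illustration: `mlp_step`, the tensor shapes, and the `time_fn` helper are hypothetical and not taken from the TTT codebase, and the timings it prints will vary by hardware. It compares calling a small function op-by-op against calling the same function wrapped in `jax.jit`, which traces it once and compiles it with XLA.

```python
import time

import jax
import jax.numpy as jnp


def mlp_step(w1, w2, x):
    """Hypothetical two-layer forward pass (illustrative, not TTT code)."""
    h = jnp.tanh(x @ w1)
    return jnp.sum((h @ w2) ** 2)


# jax.jit traces the function once and compiles it with XLA, which may
# fuse the matmuls and elementwise ops into fewer kernels.
mlp_step_jit = jax.jit(mlp_step)

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
w1 = jax.random.normal(k1, (256, 256))
w2 = jax.random.normal(k2, (256, 256))
x = jax.random.normal(k3, (64, 256))

# Warm-up: the first jitted call pays the tracing and compilation cost.
mlp_step_jit(w1, w2, x).block_until_ready()


def time_fn(fn, n=100):
    """Average wall-clock seconds per call over n calls."""
    start = time.perf_counter()
    for _ in range(n):
        # block_until_ready() waits for the asynchronous computation to finish.
        fn(w1, w2, x).block_until_ready()
    return (time.perf_counter() - start) / n


eager_s = time_fn(mlp_step)
jit_s = time_fn(mlp_step_jit)
print(f"op-by-op: {eager_s * 1e6:.1f} us/call, jit: {jit_s * 1e6:.1f} us/call")
```

On the PyTorch side, `torch.compile` (available since PyTorch 2.0) plays a broadly similar role. Whether it narrows the gap for the TTT layers in this repo is not something this thread establishes.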