test-time-training / ttt-lm-pytorch

Official PyTorch implementation of Learning to (Learn at Test Time): RNNs with Expressive Hidden States
MIT License
1.01k stars · 56 forks

The current code doesn't load any model. #20

Closed · MeNicefellow closed this issue 2 months ago

karan-dalal commented 2 months ago

This code is just for inference.

MeNicefellow commented 2 months ago

> This code is just for inference.

For inference, you need to load a checkpoint first; otherwise the model is randomly initialized. How can you run inference with that?

karan-dalal commented 2 months ago

You can use our JAX codebase to train a model, then load it into PyTorch.

MeNicefellow commented 2 months ago

> You can use our JAX codebase to train a model, then load it into PyTorch.

But you already have trained models here: https://huggingface.co/Test-Time-Training. Why don't you show how to load them?

Z-Z188 commented 1 month ago

> You can use our JAX codebase to train a model, then load it into PyTorch.
>
> But you already have trained models here: https://huggingface.co/Test-Time-Training. Why don't you show how to load them?

Hello, author! I have the same question. Could you teach us how to load the trained models?

karan-dalal commented 1 month ago

https://github.com/test-time-training/ttt-lm-pytorch/issues/22

MeNicefellow commented 1 month ago

@karan-dalal That issue is about converting to PyTorch. I don't understand why it is so hard to provide complete example code that loads the model in PyTorch and performs inference, by updating the example in the README.md. Your model is quite different from existing models, so a user-friendly example would be very helpful!

karan-dalal commented 1 month ago

Our model wraps the HuggingFace `PreTrainedModel` class. Once you convert the PyTorch checkpoint, you can simply load it with something like:

  # pt_args is assumed to hold the converted checkpoint path and model kwargs.
  model = TttForCausalLM.from_pretrained(
      pt_args.weight_path,        # path to the converted PyTorch checkpoint
      torch_dtype=torch.float32,  # load weights in full precision
      device_map="auto",          # place the model across available devices
      **pt_args.model_args,       # extra model configuration arguments
  )
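
To make that concrete, here is a hedged end-to-end sketch. The checkpoint path and tokenizer name below are placeholders, and the import assumes `TttForCausalLM` is defined in this repo's own model file rather than in the transformers package:

  import torch
  from transformers import AutoTokenizer
  from ttt import TttForCausalLM  # assumption: the repo's model definition module

  # Placeholder path to a checkpoint converted from the JAX codebase.
  weight_path = "./ttt-1b-pytorch"

  model = TttForCausalLM.from_pretrained(
      weight_path,
      torch_dtype=torch.float32,
      device_map="auto",
  )

  # Placeholder tokenizer; use whichever tokenizer the checkpoint was trained with.
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

  inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
  with torch.no_grad():
      output = model.generate(**inputs, max_new_tokens=32)
  print(tokenizer.decode(output[0], skip_special_tokens=True))
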
helldog-star commented 1 month ago

> Our model wraps the HuggingFace `PreTrainedModel` class. Once you convert the PyTorch checkpoint, you can simply load it with something like:
>
>     model = TttForCausalLM.from_pretrained(
>         pt_args.weight_path,
>         torch_dtype=torch.float32,
>         device_map="auto",
>         **pt_args.model_args,
>     )

The latest version of `transformers` does not include a `ttt` model, nor does it contain a `TttForCausalLM` class. Could you please share the version of `transformers` you are using, or provide the code for `TttForCausalLM`?
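
Since this repo is billed as the official PyTorch implementation, `TttForCausalLM` is presumably defined in the repository's own model code rather than shipped with the `transformers` package. Under that assumption, the import would come from the repo module instead of `transformers` (the module name `ttt` here is a guess):

  # Assumption: the class lives in this repo's model file, not in transformers.
  from ttt import TttForCausalLM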