Closed MeNicefellow closed 2 months ago
This code is just for inference.
For inference, you need to load a checkpoint first; otherwise the model is randomly initialized. How can you infer with that?
You can use our JAX codebase to train a model, then load it into PyTorch.
But you already have trained models here https://huggingface.co/Test-Time-Training. Why don't you show how to load them?
Hello, author! I have the same question. Could you teach us how to load the train models?
@karan-dalal That is about converting to PyTorch. I wonder why it is so hard for you to provide complete example code that loads the model in PyTorch and performs inference, by updating the example in the README.md file. Your model is so different from existing models that a user-friendly example would be very helpful!
Our model wraps the Hugging Face `PreTrainedModel` class. Once you convert the checkpoint to PyTorch, you can load it with something like:
```python
model = TttForCausalLM.from_pretrained(
    pt_args.weight_path,
    torch_dtype=torch.float32,
    device_map="auto",
    **pt_args.model_args,
)
```
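For readers wondering what the JAX-to-PyTorch conversion step actually involves: at its core it flattens the nested Flax parameter tree into the flat, dot-separated keys a PyTorch `state_dict` uses. The sketch below shows only that key-flattening idea with made-up parameter names; the repo's actual converter and its real key mapping will differ.

```python
# Hedged sketch: flatten a nested (Flax-style) param dict into flat
# PyTorch-style state_dict keys. Parameter names here are illustrative.

def flatten_params(params, prefix=""):
    """Flatten a nested dict of arrays into {'a.b.c': array} form."""
    flat = {}
    for key, value in params.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into sub-modules, extending the dotted prefix.
            flat.update(flatten_params(value, full_key))
        else:
            flat[full_key] = value
    return flat

# Toy example with invented names, not the repo's real parameter tree:
jax_params = {"model": {"embed": {"kernel": [1.0]}, "lm_head": {"kernel": [2.0]}}}
state_dict = flatten_params(jax_params)
# state_dict keys: 'model.embed.kernel', 'model.lm_head.kernel'
```

A real converter would additionally transpose weight matrices where Flax and PyTorch layout conventions differ and rename keys to match the PyTorch module names.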
The latest version of transformers does not have the 'ttt' model, nor does it contain the 'TttForCausalLM' class. Could you please provide the version of transformers you are using? Or could you provide the code for 'TttForCausalLM'?
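Since `TttForCausalLM` is not part of the `transformers` package, it has to be imported from the authors' own PyTorch code (e.g. a standalone `ttt.py`), not from `transformers`. As a generic sketch of that pattern, the helper below loads a class from a local Python source file; the file and class names are placeholders, not confirmed paths in the repo.

```python
# Hedged sketch: load a model class from a standalone .py file, since the
# class is not registered inside the installed transformers package.
import importlib.util
from pathlib import Path


def load_class_from_file(py_file, class_name):
    """Load `class_name` from a local Python source file."""
    spec = importlib.util.spec_from_file_location(Path(py_file).stem, py_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)


# Hypothetical usage, assuming the repo ships a ttt.py defining the class:
# TttForCausalLM = load_class_from_file("ttt.py", "TttForCausalLM")
```

Alternatively, if the checkpoint repository ships its modeling code alongside the weights, `from_pretrained(..., trust_remote_code=True)` may work, but that depends on how the authors published it.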