Closed LTEnjoy closed 1 year ago
Hi! I haven't provided an interface for passing checkpoints during pretraining, but you can make a small change to the script to achieve this.
When running downstream tasks, we provide two interfaces for passing checkpoints. By setting model_checkpoint, you can load a pretrained model checkpoint. By setting checkpoint, you can load the models, mlps, and optimizer state into the solver.
https://github.com/DeepGraphLearning/GearNet/blob/932e7e3ac56c733f31231e1e0cdd03c158c7ab96/util.py#L128-L136
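The distinction between the two interfaces can be sketched in plain PyTorch (a minimal, hypothetical example — the actual logic lives in the linked util.py, and the file names here are placeholders):

```python
import torch
from torch import nn, optim

# Stand-ins for the solver's components (hypothetical, for illustration only).
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters())

# --- model_checkpoint: restore only the pretrained model weights ---
torch.save(model.state_dict(), "model_only.pth")
model.load_state_dict(torch.load("model_only.pth"))

# --- checkpoint: restore the full solver state (model + optimizer) ---
torch.save({"model": model.state_dict(),
            "optimizer": optimizer.state_dict()}, "solver.pth")
state = torch.load("solver.pth")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
```

The second form is what you need for resuming training, since the optimizer state (e.g. Adam moments) is required to continue from where you left off.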
For your case, if you want to continue pretraining from a checkpoint, three changes are needed. First, change the save
function in pretrain.py
, since the current one only saves the model and discards the mlp and optimizer. Second, add a checkpoint
argument to build_pretrain_solver
for passing the solver checkpoint, similar to the build_downstream_solver
shown above. Finally, change the config files to add the checkpoint
argument and provide the path to your checkpoint.
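Put together, the first two changes might look roughly like this (a hedged sketch with hypothetical signatures — the real save function and build_pretrain_solver in the repo differ in detail):

```python
import torch
from torch import nn, optim

def save(model, mlp, optimizer, path):
    """Step 1: save the full training state, not just the model,
    so that pretraining can be resumed (the original save()
    drops the mlp and optimizer)."""
    torch.save({"model": model.state_dict(),
                "mlp": mlp.state_dict(),
                "optimizer": optimizer.state_dict()}, path)

def build_pretrain_solver(model, mlp, optimizer, checkpoint=None):
    """Step 2: accept a `checkpoint` argument, mirroring
    build_downstream_solver, and restore all three states."""
    if checkpoint is not None:
        state = torch.load(checkpoint)
        model.load_state_dict(state["model"])
        mlp.load_state_dict(state["mlp"])
        optimizer.load_state_dict(state["optimizer"])
    return model, mlp, optimizer
```

Step 3 is then just a config-file entry (e.g. a `checkpoint:` key with the path to your .pth file) that gets forwarded into build_pretrain_solver.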
Nevertheless, I believe that even after pretraining the model for only 25 epochs, you can still get a very strong model.
OK! Thank you for such a detailed answer! I'll try it later!
Hello! Thanks for your great work! For some reason, I couldn't run the whole training loop in your pretraining scripts, but I got some checkpoints like "model_epoch_25.pth". The question is: how can I load this checkpoint and continue my pretraining? Looking forward to your reply!