DeepGraphLearning / GearNet

GearNet and Geometric Pretraining Methods for Protein Structure Representation Learning, ICLR'2023 (https://arxiv.org/abs/2203.06125)
MIT License

How can I load pretrained weights from a checkpoint to continue pretraining? #10

Closed: LTEnjoy closed this issue 1 year ago

LTEnjoy commented 1 year ago

Hello! Thanks for your great work! For some reason, I couldn't run the whole training loop in your pretraining scripts, but I do have some checkpoints like "model_epoch_25.pth". How can I load one of these checkpoints and finish my pretraining? Looking forward to your reply!

Oxer11 commented 1 year ago

Hi! I haven't provided an interface for passing checkpoints to pretraining, but you can make a small change to the script to achieve this.

When running downstream tasks, we provide an interface for passing either a solver checkpoint or a model checkpoint. Setting model_checkpoint loads only the pretrained model weights. Setting checkpoint loads the model, the MLPs, and the optimizer into the solver. https://github.com/DeepGraphLearning/GearNet/blob/932e7e3ac56c733f31231e1e0cdd03c158c7ab96/util.py#L128-L136
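To illustrate the difference between the two options, here is a minimal plain-PyTorch sketch. It assumes a solver checkpoint is a dict with a "model" entry (the whole task, i.e. encoder plus MLPs) and an "optimizer" entry, which is how TorchDrug's Engine.save appears to lay it out; ToyTask and apply_checkpoints are hypothetical stand-ins, not GearNet code. It also shows that a model-only file such as "model_epoch_25.pth" can warm-start the encoder, while the MLP and optimizer still start from fresh initialization.

```python
import os
import tempfile

import torch
from torch import nn


class ToyTask(nn.Module):
    """Hypothetical stand-in for a task that wraps an encoder and an MLP head."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(8, 4)  # plays the role of the GearNet encoder
        self.mlp = nn.Linear(4, 2)    # plays the role of the task-specific MLP


def apply_checkpoints(task, optimizer, cfg):
    """Hypothetical loader driven by two optional config keys.

    "checkpoint":       full solver state {"model": ..., "optimizer": ...}
                        covering the whole task (encoder + MLP) and optimizer.
    "model_checkpoint": a bare state_dict of the encoder only.
    If both are set, the encoder weights from model_checkpoint win (loaded last).
    """
    if cfg.get("checkpoint") is not None:
        state = torch.load(cfg["checkpoint"], map_location="cpu")
        task.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
    if cfg.get("model_checkpoint") is not None:
        encoder_state = torch.load(cfg["model_checkpoint"], map_location="cpu")
        task.model.load_state_dict(encoder_state)


# Usage: write both kinds of checkpoint, then load each into a fresh task.
torch.manual_seed(0)
src = ToyTask()
src_opt = torch.optim.Adam(src.parameters(), lr=1e-3)
with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model_epoch_25.pth")
    solver_path = os.path.join(tmp, "solver.pth")
    torch.save(src.model.state_dict(), model_path)  # encoder weights only
    torch.save({"model": src.state_dict(), "optimizer": src_opt.state_dict()},
               solver_path)                          # encoder + MLP + optimizer

    partial = ToyTask()
    apply_checkpoints(partial, torch.optim.Adam(partial.parameters(), lr=1e-3),
                      {"model_checkpoint": model_path})

    full = ToyTask()
    apply_checkpoints(full, torch.optim.Adam(full.parameters(), lr=1e-3),
                      {"checkpoint": solver_path})
```

After this runs, partial shares only the encoder weights with src, whereas full matches src in both the encoder and the MLP.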

In your case, continuing pretraining from a checkpoint takes three changes. First, modify the save function in pretrain.py, since the current one saves only the model and discards the MLPs and the optimizer. Second, add a checkpoint interface to build_pretrain_solver for passing the solver checkpoint, similar to build_downstream_solver shown above. Finally, add the checkpoint argument to the config file and set it to the path of your checkpoint.
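The three steps could look roughly like the following in plain PyTorch. This is a hypothetical sketch, not GearNet's code: ToyTask stands in for the pretraining task, save_full for the modified save function, build_toy_solver for build_pretrain_solver with an added checkpoint option, and the dict passed to it for the config file. Storing the epoch number in the checkpoint is an extra assumption so that training can resume where it stopped.

```python
import os
import tempfile

import torch
from torch import nn


class ToyTask(nn.Module):
    """Hypothetical stand-in for a pretraining task: encoder + MLP head."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(8, 4)  # plays the role of the GearNet encoder
        self.mlp = nn.Linear(4, 1)    # plays the role of the pretraining MLP

    def forward(self, x):
        return self.mlp(torch.relu(self.model(x)))


def save_full(task, optimizer, epoch, path):
    # Step 1 (hypothetical): keep the MLP and optimizer state, not just the encoder.
    torch.save({"model": task.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch}, path)


def build_toy_solver(cfg):
    # Step 2 (hypothetical): an optional "checkpoint" entry, analogous to the
    # downstream builder; returns the epoch to start from.
    task = ToyTask()
    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    start_epoch = 0
    if cfg.get("checkpoint") is not None:
        state = torch.load(cfg["checkpoint"], map_location="cpu")
        task.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        start_epoch = state["epoch"] + 1
    return task, optimizer, start_epoch


def train_one_epoch(task, optimizer):
    # Placeholder objective on random data, just to populate optimizer state.
    x = torch.randn(16, 8)
    loss = task(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


torch.manual_seed(0)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "solver_epoch_25.pth")
    task, opt, start = build_toy_solver({})
    train_one_epoch(task, opt)
    save_full(task, opt, epoch=25, path=ckpt)
    # Step 3 (hypothetical): the config now points "checkpoint" at the saved file.
    resumed, resumed_opt, start_epoch = build_toy_solver({"checkpoint": ckpt})
```

Because the optimizer state (e.g. Adam's moment estimates) is restored along with the weights, training resumes from epoch 26 instead of restarting the optimizer from scratch.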

Nevertheless, I believe that even with only 25 epochs of pretraining, you can still get a very strong model.

LTEnjoy commented 1 year ago

OK! Thank you for such a detailed answer! I'll try it later!