ml4bio / e2efold

PyTorch implementation of "RNA Secondary Structure Prediction By Learning Unrolled Algorithms"
MIT License

When to early stop to obtain the well-performing pre-trained models? #9

Open. irleader opened this issue 3 years ago.

irleader commented 3 years ago

I tested e2efold with the provided pre-trained models. The results on RNAStralign are very close to (slightly lower than) those in the paper, but the results on ArchiveII are somewhat lower than those reported in the paper (is this within a reasonable range?):

RNAStralign short:
Average testing F1 score with learning post-processing: 0.84188575
Average testing F1 score with learning post-processing allow shift: 0.8534005
Average testing precision with learning post-processing: 0.8722457
Average testing precision with learning post-processing allow shift: 0.88652223
Average testing recall with learning post-processing: 0.8236133
Average testing recall with learning post-processing allow shift: 0.83382463

RNAStralign long:
Average testing F1 score with learning post-processing: 0.7842659
Average testing F1 score with learning post-processing allow shift: 0.7964913
Average testing precision with learning post-processing: 0.8558946
Average testing precision with learning post-processing allow shift: 0.8695002
Average testing recall with learning post-processing: 0.72498244
Average testing recall with learning post-processing allow shift: 0.7361069

ArchiveII short:
Average testing F1 score with learning post-processing: 0.55401367
Average testing F1 score with learning post-processing allow shift: 0.5792161
Average testing precision with learning post-processing: 0.6059438
Average testing precision with learning post-processing allow shift: 0.63850296
Average testing recall with learning post-processing: 0.53043944
Average testing recall with learning post-processing allow shift: 0.5530917

ArchiveII long:
Average testing F1 score with learning post-processing: 0.11524726
Average testing F1 score with learning post-processing allow shift: 0.13892774
Average testing precision with learning post-processing: 0.12628905
Average testing precision with learning post-processing allow shift: 0.15100932
Average testing recall with learning post-processing: 0.10790618
Average testing recall with learning post-processing allow shift: 0.13134095
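(For context, this is how I understand these metrics; the sketch below is my own reading, not the evaluation code from this repo. I assume precision/recall/F1 are computed over sets of predicted vs. ground-truth base pairs, and that "allow shift" counts a predicted pair as correct if it matches a true pair up to a one-position offset in either index.)

```python
# Minimal sketch of pair-set precision/recall/F1, assuming "allow shift" means
# a one-position tolerance on either pair index. My reading only, not the
# repository's evaluation code.
def pair_metrics(pred_pairs, true_pairs, allow_shift=False):
    """pred_pairs, true_pairs: sets of (i, j) base-pair index tuples."""
    if allow_shift:
        def hit(p):
            # A predicted pair counts as a true positive if it matches a true
            # pair exactly or with one index shifted by +/-1.
            i, j = p
            candidates = {(i, j), (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)}
            return bool(candidates & true_pairs)
        tp = sum(hit(p) for p in pred_pairs)
    else:
        tp = len(pred_pairs & true_pairs)
    precision = tp / len(pred_pairs) if pred_pairs else 0.0
    recall = tp / len(true_pairs) if true_pairs else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1
```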

In order to reproduce the training process, I also trained e2efold myself. However, the results are terrible; I believe this is because both the pre-processing and post-processing networks are trained for a fixed number of epochs without early stopping.
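(For what it's worth, what I have in mind is a generic validation-based early stop, roughly like the sketch below. `train_one_epoch` and `evaluate_f1` are placeholders the caller supplies, not functions from this repo, and the patience value is arbitrary.)

```python
# Hypothetical early-stopping wrapper; not code from this repository.
# `train_one_epoch(model)` and `evaluate_f1(model)` are caller-supplied
# placeholders: one pass over the training set, and F1 on a held-out
# validation split.
import copy

def train_with_early_stopping(model, train_one_epoch, evaluate_f1,
                              max_epochs=100, patience=5):
    best_f1 = 0.0
    best_state = copy.deepcopy(model.state_dict())
    epochs_since_improvement = 0
    for epoch in range(max_epochs):
        train_one_epoch(model)
        val_f1 = evaluate_f1(model)
        if val_f1 > best_f1:
            best_f1 = val_f1
            best_state = copy.deepcopy(model.state_dict())
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                break  # validation F1 has plateaued; stop early
    model.load_state_dict(best_state)  # restore the best checkpoint
    return model, best_f1
```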

Could you kindly indicate when to stop training the pre-processing and post-processing networks, respectively, in order to obtain the pre-trained models you provided (at which epochs, or any code you use to decide when to stop)? And where should these epoch numbers be modified (epoches_first and epoches_third in config.json?)
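(To make it concrete, this is the kind of edit I am currently making; the key names are the ones I see in config.json, but the values and the path are placeholders, not the settings used for the released models.)

```python
# Hypothetical snippet showing where I am changing the epoch counts.
# The key names come from config.json; the values here are placeholders.
import json

with open("config.json") as f:          # path to the experiment's config file
    config = json.load(f)

config["epoches_first"] = 50            # epochs for the first training stage (placeholder)
config["epoches_third"] = 10            # epochs for the third training stage (placeholder)

with open("config.json", "w") as f:
    json.dump(config, f, indent=4)
```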