materialsvirtuallab / megnet

Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals
BSD 3-Clause "New" or "Revised" License

Training with save_checkpoint=True disables validation metrics logging #214

Status: Open. Opened by a-ws-m 3 years ago.

a-ws-m commented 3 years ago

I've recently been training some models with MEGNet and using TensorBoard to track the model metrics. At first I was confused about why the validation metrics weren't showing up in the output: MEGNet's ModelCheckpointMAE callback was reporting improvements to val_mae as expected, so I knew I had passed the validation data correctly. I did some digging and found this. I understand the logic, but I don't think hiding the validation data from Keras should be the default behaviour, because it prevents other callbacks that track validation metrics from working as expected.
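For context, Keras passes validation metrics to callbacks under keys prefixed with `val_`, and only when the validation data is given to `fit` itself. The sketch below mimics how a TensorBoard-style logger might split the epoch `logs` dictionary into training and validation groups. It is a plain-Python illustration, not the TensorBoard source; `split_epoch_logs` is a hypothetical helper and the metric values are made up.

```python
def split_epoch_logs(logs):
    """Split a Keras-style epoch `logs` dict into training and validation metrics.

    Keras names validation metrics with a "val_" prefix (e.g. "val_mae").
    If the validation data never reaches Keras, no such keys exist, so a
    logger that relies on them has nothing to record.
    """
    train = {k: v for k, v in logs.items() if not k.startswith("val_")}
    val = {k[len("val_"):]: v for k, v in logs.items() if k.startswith("val_")}
    return train, val


# Validation data passed to fit(): both groups are populated.
train, val = split_epoch_logs({"loss": 0.12, "mae": 0.25, "val_loss": 0.15, "val_mae": 0.30})
print(train)  # {'loss': 0.12, 'mae': 0.25}
print(val)    # {'loss': 0.15, 'mae': 0.3}

# Validation data hidden from Keras: the validation group is empty.
train, val = split_epoch_logs({"loss": 0.12, "mae": 0.25})
print(val)    # {}
```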

I also checked the code for the ModelCheckpointMAE callback and noticed that it computes the validation MAE manually. The logs argument to on_epoch_end already contains the pre-computed metrics, as long as the model was compiled with those metrics; the TensorBoard callback code, for example, simply pulls the pre-computed validation metrics from this argument. It may therefore be more efficient to compile the model with the mae metric by default and read its value from logs, which would also avoid computing the validation metrics twice.
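A minimal sketch of that suggestion, assuming the model was compiled with `metrics=["mae"]` and fit with validation data so that `logs` contains `"val_mae"`. The class name `BestValMAETracker` is hypothetical and it is written framework-free so it runs on its own; in a real Keras program it would subclass `keras.callbacks.Callback`, and Keras would call `on_epoch_end(epoch, logs)` after each epoch.

```python
import math
import warnings


class BestValMAETracker:
    """Track the best validation MAE by reading it from Keras-style epoch logs.

    Sketch only: assumes "val_mae" is already in logs because the model was
    compiled with metrics=["mae"] and given validation data in fit().
    """

    def __init__(self, monitor="val_mae"):
        self.monitor = monitor
        self.best = math.inf
        self.best_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            # Metric missing: not compiled with it, or no validation data given.
            warnings.warn(f"{self.monitor!r} not in logs; keys: {sorted(logs)}")
            return
        if current < self.best:
            # Reuse the value Keras already computed instead of re-predicting
            # the validation set inside the callback.
            self.best = current
            self.best_epoch = epoch
            # A checkpoint callback would save weights here (e.g. via self.model).


# Simulated epochs with made-up validation MAE values.
tracker = BestValMAETracker()
for epoch, val_mae in enumerate([0.40, 0.31, 0.35, 0.28]):
    tracker.on_epoch_end(epoch, {"loss": 0.1, "val_mae": val_mae})
print(tracker.best, tracker.best_epoch)  # 0.28 3
```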

chc273 commented 3 years ago

@a-ws-m thanks for the comment. The metrics are indeed computed manually. The original reason is that the model is designed to train on intensive quantities, but when we train on extensive quantities we want the reported metric to be correct, which requires multiplying the prediction by the number of atoms. This applies, for example, to U0 in the QM9 data. The default Keras API cannot satisfy this requirement.
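A worked illustration of that point, using a hypothetical helper and made-up numbers: if the network predicts a per-atom (intensive) value, the extensive prediction for a structure is that value times its atom count, and the MAE of interest is computed on those extensive values.

```python
def extensive_mae(per_atom_pred, n_atoms, extensive_target):
    """MAE on an extensive quantity from per-atom (intensive) predictions.

    Each structure's prediction is scaled by its number of atoms before
    being compared with the extensive target (e.g. a total energy like U0).
    """
    if not (len(per_atom_pred) == len(n_atoms) == len(extensive_target)):
        raise ValueError("inputs must have the same length")
    errors = [abs(p * n - t) for p, n, t in zip(per_atom_pred, n_atoms, extensive_target)]
    return sum(errors) / len(errors)


# Two structures with 3 and 5 atoms; targets are totals, predictions are per atom.
print(extensive_mae([-1.0, -2.0], [3, 5], [-3.5, -9.0]))  # 0.75
```

On the same made-up data, the per-atom MAE (comparing -1.0 with -3.5/3 and -2.0 with -9.0/5) is only about 0.18, so an `mae` metric computed on the intensive target would not report the extensive error.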

The current procedure only provides a convenient entry point for training such models. If you have other needs, feel free to write your own training procedure using Keras' fit API; the current code can serve as a reference.
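One possible way to reconcile the two concerns in a custom training procedure is a callback that computes the extensive validation MAE and writes it back into `logs`, so callbacks placed after it can read it. This rests on an assumption to verify against your Keras version: that `fit` passes the same `logs` dict to each callback in the `callbacks=[...]` list in order. The class, the `val_extensive_mae` key, and `predict_per_atom` (a stand-in for running the model on the validation structures) are all hypothetical; the class is framework-free here and would subclass `keras.callbacks.Callback` in practice.

```python
class ExtensiveValMAELogger:
    """Write an extensive-quantity validation MAE into Keras-style epoch logs.

    Sketch only: intended to be listed before TensorBoard in callbacks=[...]
    so the extra key is visible downstream. `predict_per_atom` is a
    hypothetical zero-argument function returning per-atom predictions
    for the validation structures.
    """

    def __init__(self, predict_per_atom, n_atoms, extensive_target, key="val_extensive_mae"):
        self.predict_per_atom = predict_per_atom
        self.n_atoms = list(n_atoms)
        self.extensive_target = list(extensive_target)
        self.key = key

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return
        preds = self.predict_per_atom()
        # Scale per-atom predictions by atom count before comparing with totals.
        errors = [abs(p * n - t) for p, n, t in zip(preds, self.n_atoms, self.extensive_target)]
        # Mutate the logs dict in place so later callbacks can read the metric.
        logs[self.key] = sum(errors) / len(errors)


# Simulated epoch with a fixed stand-in predictor and made-up data.
logs = {"loss": 0.1, "val_loss": 0.2}
cb = ExtensiveValMAELogger(lambda: [-1.0, -2.0], n_atoms=[3, 5], extensive_target=[-3.5, -9.0])
cb.on_epoch_end(0, logs)
print(logs["val_extensive_mae"])  # 0.75
```

The `val_` prefix on the key is a deliberate choice, so that a logger grouping metrics by that prefix would treat it as a validation metric.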

Meanwhile, I will look into how your suggestions can be included. Thanks!