Description
This PR solves the following inconsistencies:
- Validation and test set metrics (which are per-epoch metrics) are now calculated on values accumulated over all batches. E.g., the RMSE is calculated from the MSE accumulated over the batches.
- `ReduceLROnPlateau` prints the current learning rate.
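The sketch below shows why accumulation matters. It compares the RMSE computed from squared errors accumulated over an epoch with the average of per-batch RMSE values. Both helper names are hypothetical and plain Python lists stand in for tensors; this is an illustration, not the PR's implementation.

```python
import math


def rmse_from_batches(batches):
    """Per-epoch RMSE from squared errors accumulated over all batches.

    `batches` is a list of (predictions, targets) pairs (hypothetical format).
    """
    sum_sq = 0.0
    count = 0
    for preds, targets in batches:
        for p, t in zip(preds, targets):
            sum_sq += (p - t) ** 2
            count += 1
    return math.sqrt(sum_sq / count)


def mean_of_batch_rmse(batches):
    """Naive alternative: average of per-batch RMSE values."""
    values = []
    for preds, targets in batches:
        mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
        values.append(math.sqrt(mse))
    return sum(values) / len(values)


# Two equally sized batches with absolute errors 1 and 3.
batches = [([0.0, 0.0], [1.0, 1.0]), ([0.0, 0.0], [3.0, 3.0])]
print(rmse_from_batches(batches))   # sqrt(20 / 4) = sqrt(5) ~ 2.236
print(mean_of_batch_rmse(batches))  # (1 + 3) / 2 = 2.0
```

Even with equal batch sizes, the mean of per-batch square roots differs from the square root of the accumulated mean. Only the accumulated version is the RMSE over the whole epoch.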
This PR also introduces:
- $E_i$ is scaled by the standard deviation and shifted by the mean of the QM atomic energies.
- The tolerance and sensitivity of the LR scheduler and of the early-stopping callbacks are reduced.
Minor issues that this PR resolves:
- Converting PhysNet outputs from PyTorch to JAX, with DLPack handling the tensor mapping, led to a striding error: an operation in the PhysNet readout layer produced an output tensor with a non-trivial memory layout. This is resolved by calling `.contiguous()` to ensure a contiguous memory layout before converting the tensor to a JAX array.
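As an illustration only, the NumPy sketch below shows what a non-trivial memory layout looks like and how forcing a contiguous copy fixes it. NumPy stands in for the PyTorch tensor here, `np.ascontiguousarray` plays the role of `torch.Tensor.contiguous()`, and the transpose is just an example of a stride-changing operation. It is not necessarily the operation in the PhysNet readout layer.

```python
import numpy as np

# A transposed view shares memory with its base array but has swapped
# strides, so it is not C-contiguous (row-major).
x = np.arange(6, dtype=np.float32).reshape(2, 3)
view = x.T
print(view.flags["C_CONTIGUOUS"])   # False
print(view.strides)                 # (4, 12)

# Copy the data into a row-major buffer with standard strides, analogous
# to calling .contiguous() on a PyTorch tensor before the DLPack export.
fixed = np.ascontiguousarray(view)
print(fixed.flags["C_CONTIGUOUS"])  # True
print(fixed.strides)                # (8, 4)
```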
Details on the $E_i$ scaling:
The total energy $E$ is calculated as $E = \sum_i E_i$. The expression for $E_i$ is changed to $E_i \leftarrow E_i \cdot \sigma(E_i) + \mu(E_i)$. Here $\sigma(E_i)$ is the standard deviation and $\mu(E_i)$ the average of the per-atom QM energies, with the self energies already removed.
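A minimal sketch of this scaling in plain Python follows. It rests on several assumptions: the helper names and input format are hypothetical, per-atom energies are taken as (total QM energy minus summed self energies) divided by the number of atoms, and the population standard deviation is used. The PR may compute these statistics differently.

```python
import statistics


def per_atom_energy_stats(energies, self_energy_sums, n_atoms):
    """Mean and std of per-atom QM energies, self energies removed first.

    All argument names are hypothetical: total QM energy per molecule,
    summed atomic self energies per molecule, and atom count per molecule.
    """
    per_atom = [(e - s) / n for e, s, n in zip(energies, self_energy_sums, n_atoms)]
    return statistics.mean(per_atom), statistics.pstdev(per_atom)


def total_energy(raw_atomic_outputs, mu, sigma):
    """Apply E_i <- E_i * sigma + mu to every atom, then return E = sum_i E_i."""
    return sum(e * sigma + mu for e in raw_atomic_outputs)


mu, sigma = per_atom_energy_stats(
    energies=[-10.0, -22.0], self_energy_sums=[-4.0, -10.0], n_atoms=[2, 3]
)
print(mu, sigma)                               # -3.5 0.5
print(total_energy([1.0, -1.0, 0.0], mu, sigma))  # -10.5
```

Because of the shift, a molecule with $N$ atoms whose raw atomic outputs are all zero is predicted at $N \cdot \mu(E_i)$, the dataset's average per-atom energy times the atom count. The prediction excludes self energies, matching how $\mu(E_i)$ is defined.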
Notes: With these changes, the validation set RMSE is around 40 kJ/mol in the first epoch. Across multiple training runs with SchNet, it took about 100 epochs to reach a validation RMSE of 8 kJ/mol and another 100 epochs to drop below 4 kJ/mol. Training 100 epochs on QM9 on a node with 4 × RTX 3090 GPUs takes around 30 minutes.