ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

RMSE increases when training on multi-gpu #343

Closed by gconter 2 months ago

gconter commented 3 months ago

I am not an ML expert, so please forgive me if I am unclear or if I ask stupid questions (or both).

I ran the same MACE training script, with identical settings (epochs, batch size, learning rate, energy/forces weights, hidden irreps, and so on), on a single GPU and on 16 GPUs (using the multi-gpu branch). With multiple GPUs, the energy RMSE increases roughly 10x and the force RMSE roughly 3x, even though the loss values are essentially comparable. Is the model really less accurate (as the RMSE suggests), or not (as the loss suggests)? If it is, how do I regain the lost accuracy?

I tried adjusting the batch size (from 10 to 20, 40, 60 and 80) and the learning rate (from 0.01 to 0.013 and 0.02), without success: both changes made things worse.

Thank you in advance.

ilyes319 commented 3 months ago

Hey,

I would need to see the different log files for the different runs to understand what is happening. Are you using the same branch for the single GPU and multi-GPU trainings?

One does expect differences between multi-GPU and single-GPU training, but only in the large batch size limit, and not of this magnitude if both runs are properly converged. When you train on multiple GPUs, the effective batch size is the number of GPUs times the batch size per GPU. MACE is trained using stochastic gradient descent. A larger batch size means less stochasticity, which can prevent good convergence in some cases. Moreover, when comparing two runs with different batch sizes, you need to compare them at an equal number of gradient updates: the bigger the batch size, the fewer gradient updates the model receives for the same number of epochs.
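
To make the arithmetic concrete, here is a minimal sketch in plain Python (it does not use any MACE code). The function names, the dataset size of 8000 configurations, and the 100 epochs are hypothetical values chosen for illustration. It assumes that every GPU processes one batch per step and that the last partial batch of an epoch is kept.

```python
import math


def effective_batch_size(per_gpu_batch_size: int, num_gpus: int) -> int:
    """Number of configurations contributing to one gradient update."""
    return per_gpu_batch_size * num_gpus


def gradient_updates(num_samples: int, per_gpu_batch_size: int,
                     num_gpus: int, epochs: int) -> int:
    """Total optimizer steps, assuming the last partial batch is kept."""
    steps_per_epoch = math.ceil(
        num_samples / effective_batch_size(per_gpu_batch_size, num_gpus)
    )
    return steps_per_epoch * epochs


# Hypothetical training set size and epoch count, for illustration only.
num_samples = 8000
epochs = 100

single_gpu = gradient_updates(num_samples, per_gpu_batch_size=10, num_gpus=1, epochs=epochs)
multi_gpu = gradient_updates(num_samples, per_gpu_batch_size=10, num_gpus=16, epochs=epochs)

print("effective batch size on 16 GPUs:", effective_batch_size(10, 16))
print("updates on 1 GPU:", single_gpu)
print("updates on 16 GPUs:", multi_gpu)
print("ratio:", single_gpu / multi_gpu)
```

With these assumed numbers, the 16-GPU run uses an effective batch size of 160 and receives 16 times fewer gradient updates than the single-GPU run for the same number of epochs, so the two runs are not directly comparable epoch for epoch.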

I would recommend decreasing the batch size from 10 to 5 or 3 and/or training for longer. But again, seeing the log files would help.
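
As a rough guide to what "train for longer" could mean, the sketch below estimates how many epochs a 16-GPU run would need to reach the same number of gradient updates as a single-GPU reference run, for per-GPU batch sizes of 10, 5 and 3. It is plain Python, not MACE code. The helper `epochs_to_match`, the dataset size of 8000 configurations, and the 100-epoch reference are hypothetical. Matching the update count only gives a fairer basis for comparison; it does not guarantee the same accuracy.

```python
import math


def epochs_to_match(num_samples: int, target_updates: int,
                    per_gpu_batch_size: int, num_gpus: int) -> int:
    """Epochs needed to reach at least `target_updates` optimizer steps."""
    steps_per_epoch = math.ceil(num_samples / (per_gpu_batch_size * num_gpus))
    return math.ceil(target_updates / steps_per_epoch)


# Hypothetical reference: 1 GPU, batch size 10, 100 epochs, 8000 configurations.
num_samples = 8000
reference_updates = math.ceil(num_samples / 10) * 100

for per_gpu_batch_size in (10, 5, 3):
    needed = epochs_to_match(num_samples, reference_updates, per_gpu_batch_size, num_gpus=16)
    print(f"per-GPU batch size {per_gpu_batch_size}: "
          f"effective batch size {per_gpu_batch_size * 16}, "
          f"about {needed} epochs to match the reference")
```

Even at a per-GPU batch size of 3, the effective batch size on 16 GPUs (48) stays well above the single-GPU value of 10, which is why reducing the batch size and training for longer are suggested together.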