BingqingCheng / cace

MIT License

Addition of lightning trainer and data modules #5

Open · dking072 opened 1 month ago


Implements the TrainingTask and data classes as PyTorch Lightning modules. The main benefit is that training curves can be visualized through Lightning's logging features (TensorBoard logging by default).

Example usage:

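root = "../water.xyz"  # path to the training data
# per-element reference energies, keyed by atomic number (1 = H, 8 = O)
avge0 = {1: -187.6043857100553, 8: -93.80219285502734}
# cutoff: neighbor-list cutoff radius, in the dataset's length units
data = LightningData(root, batch_size=4, cutoff=5.5, atomic_energies=avge0)

# cace_representation: a CACE representation module constructed beforehand (not shown)
model = NetworkPotential(cace_representation)
task = LightningTrainingTask(model)
task.fit(data, max_epochs=125)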
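The `atomic_energies=avge0` argument supplies one reference energy per element. One common way to obtain and use such per-element baselines is a least-squares fit of total energies against per-frame element counts, with the model then trained on the residual. This is an assumption here, because the PR does not say how `avge0` was computed or how `LightningData` applies it. The helper names below are hypothetical, and the energies are synthetic.

```python
import numpy as np


def fit_reference_energies(counts, energies):
    """Hypothetical helper: least-squares e0 such that energies ~= counts @ e0.

    counts:   (n_frames, n_elements) number of atoms of each element per frame
    energies: (n_frames,) total energy per frame
    If counts is rank-deficient, np.linalg.lstsq returns the minimum-norm solution.
    """
    e0, *_ = np.linalg.lstsq(counts, energies, rcond=None)
    return e0


def subtract_reference(counts, energies, e0):
    """Hypothetical helper: residual energy left for the model after removing the baseline."""
    return energies - counts @ e0


# Synthetic water-only frames: k molecules -> element-count columns [n_H, n_O] = [2k, k]
k = np.arange(1, 6)
counts = np.stack([2 * k, k], axis=1).astype(float)
energies = -10.0 * k  # synthetic energies, linear in k for illustration only
e0 = fit_reference_energies(counts, energies)
residual = subtract_reference(counts, energies, e0)
```

With water-only data every frame has H:O = 2:1, so the fit can only determine 2·e_H + e_O. The minimum-norm solution therefore sets e_H = 2·e_O, giving -4 and -2 for the synthetic data above. The `avge0` values in the example show exactly this ratio (-187.604... = 2 × -93.802...), which is consistent with such a fit, although the PR does not state it. In that case the individual values are not meaningful per-element energies on their own; only their combination per molecule is.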

Support for learning-rate schedulers, an exponential moving average (EMA) of the model weights, etc. can be added in the future if needed.
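To illustrate the EMA item, the following is a minimal sketch of keeping an exponential moving average of model weights in plain PyTorch. `WeightEMA` is a hypothetical helper, not part of CACE or this PR, and the default `decay=0.99` is an arbitrary choice. In a LightningModule, `update` would typically be called from the `on_train_batch_end` hook, with the averaged copy used for validation.

```python
import copy

import torch


class WeightEMA:
    """Hypothetical helper: exponential moving average of a model's parameters."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.99):
        self.decay = decay
        # Frozen copy of the model that holds the averaged weights
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        # shadow <- decay * shadow + (1 - decay) * current
        # (parameters only; buffers such as norm statistics are not averaged)
        for p_ema, p in zip(self.shadow.parameters(), model.parameters()):
            p_ema.mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
```

A higher `decay` gives smoother averaged weights that lag further behind the current ones.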