Implements the TrainingTask and data classes as PyTorch Lightning modules. The main benefit is that training curves can be visualized through Lightning's logging features (TensorBoard logging by default).
Example usage:
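root = "../water.xyz"  # path to the training data (XYZ file)
# per-element reference energies, keyed by atomic number (1 = H, 8 = O)
avge0 = {1: -187.6043857100553, 8: -93.80219285502734}
data = LightningData(root, batch_size=4, cutoff=5.5, atomic_energies=avge0)
# cace_representation: a CACE representation module constructed beforehand
model = NetworkPotential(cace_representation)
task = LightningTrainingTask(model)
task.fit(data, max_epochs=125)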
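Per-element reference energies such as `avge0` are commonly used as a baseline: the sum of the reference energies of all atoms in a configuration is subtracted from its total energy, so the network only has to fit the remaining interaction energy. Whether `LightningData` uses `atomic_energies` in exactly this way is an assumption here; the sketch below only illustrates the arithmetic, and `baseline_energy` / `interaction_energy` are hypothetical helpers, not functions of this module.

```python
import math
from collections import Counter
from typing import Dict, Sequence


def baseline_energy(atomic_numbers: Sequence[int],
                    atomic_energies: Dict[int, float]) -> float:
    """Sum of per-element reference energies for one configuration (hypothetical helper)."""
    counts = Counter(atomic_numbers)  # e.g. {8: 1, 1: 2} for one water molecule
    return sum(n * atomic_energies[z] for z, n in counts.items())


def interaction_energy(total_energy: float,
                       atomic_numbers: Sequence[int],
                       atomic_energies: Dict[int, float]) -> float:
    """Total energy minus the per-element baseline (hypothetical helper)."""
    return total_energy - baseline_energy(atomic_numbers, atomic_energies)


avge0 = {1: -187.6043857100553, 8: -93.80219285502734}
water = [8, 1, 1]  # atomic numbers of O, H, H
assert math.isclose(baseline_energy(water, avge0), -469.01096427513794)
```

For one water molecule the baseline is one oxygen reference energy plus two hydrogen reference energies, regardless of geometry, so it shifts every configuration with the same composition by the same constant.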
Support for learning-rate schedulers, exponential moving averages (EMA) of the weights, etc. can be added in the future if needed.