geoelements / gns

Graph Network Simulator
https://www.geoelements.org/gns/

Average training and validation losses using torch.distributed for monitoring trends #74

Closed: bumi001 closed this 4 months ago

bumi001 commented 5 months ago
  1. Added code for deterministic computation
    • helps reproduce results across runs
    • results match only when runs use the same torch and CUDA versions
  2. Added code to calculate the average training loss over the entire training dataset across all GPUs
    • uses torch.distributed.reduce
    • printed once per epoch
    • helps monitor the training trend
  3. Added code to calculate the average validation loss over the entire validation dataset across all GPUs
    • uses torch.distributed.reduce
    • printed whenever it is calculated
    • helps monitor the validation trend
  4. Added code to save, as needed, the history of the average training and validation losses as a function of epoch.
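For item 1, a minimal sketch of the settings a deterministic PyTorch run typically needs, based on PyTorch's reproducibility notes. The PR text does not list its exact calls, and `seed_everything` is a hypothetical helper name, not a function from gns. Seed NumPy as well if the data pipeline uses it.

```python
import os
import random

import torch


def seed_everything(seed: int = 0) -> None:
    """Hypothetical helper: make a run repeatable for a fixed torch/CUDA install."""
    # cuBLAS needs a fixed workspace size to be deterministic (CUDA >= 10.2);
    # this must be set before the first cuBLAS call.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    # Raise an error whenever an op has no deterministic implementation.
    torch.use_deterministic_algorithms(True)
    # Stop cuDNN from autotuning, which can pick different kernels per run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
```

Even with these settings, bitwise-identical results are only expected on the same hardware with the same torch and CUDA versions, which matches the caveat in item 1.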
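Items 2 to 4 can be sketched as follows, assuming the process group is already initialized and each rank iterates over its own shard of the data. `global_mean_loss`, `run_epoch`, and `save_history` are hypothetical names, not gns functions. Each rank sends the *sum* of its per-sample losses plus its sample count, not its own mean, so ranks that see different numbers of samples are weighted correctly. Only the destination rank of `torch.distributed.reduce` holds the result, so printing and saving happen on rank 0.

```python
from typing import List, Optional

import torch
import torch.distributed as dist


def global_mean_loss(loss_sum: float, num_samples: int, device="cpu") -> Optional[float]:
    """Sum (loss_sum, num_samples) over all ranks onto rank 0 and return the mean there.

    Use a CUDA device when the backend is NCCL. Returns None on other ranks.
    """
    totals = torch.tensor([loss_sum, float(num_samples)], dtype=torch.float64, device=device)
    dist.reduce(totals, dst=0, op=dist.ReduceOp.SUM)
    if dist.get_rank() != 0:
        return None
    return (totals[0] / totals[1]).item()


def run_epoch(model, loader, loss_fn, optimizer=None, device="cpu") -> Optional[float]:
    """Hypothetical loop: trains when an optimizer is given, otherwise validates."""
    training = optimizer is not None
    model.train(training)
    loss_sum, n = 0.0, 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y)  # assumes loss_fn returns the batch mean
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_sum += loss.item() * x.shape[0]  # undo the batch mean
            n += x.shape[0]
    return global_mean_loss(loss_sum, n, device)


def save_history(history: List[dict], epoch: int, train_loss, valid_loss,
                 path: str = "loss_history.pt") -> None:
    """Rank 0 appends this epoch's averages (valid_loss may be None) and saves."""
    if dist.get_rank() != 0:
        return
    history.append({"epoch": epoch, "train_loss": train_loss, "valid_loss": valid_loss})
    torch.save(history, path)
```

With `DistributedSampler` and `drop_last=False`, the dataset is padded by repeating samples so every rank gets the same count, so the reduced average may include a few duplicated samples.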