NVIDIA / TorchFort

An Online Deep Learning Interface for HPC programs on NVIDIA GPUs
https://nvidia.github.io/TorchFort/

Fix and improve cross-device training restart support. #23

Closed azrael417 closed 1 month ago

azrael417 commented 1 month ago

This PR enables reloading training checkpoints on a device other than the one originally used for training. Previously, the optimizer state tensors were saved on the training device, so loading the optimizer on a different device failed: training would halt because the state tensors were located on the wrong device. It turns out that `torch::load` supports a device placement argument, so this PR passes it to fix the issue. We also use this functionality to load models directly onto the target training device, replacing the existing logic that routed them through the CPU.

This PR additionally adds tests for checkpoint functionality.