This PR enables reloading training checkpoints on a different device than the one originally used for training. Previously, optimizer state tensors were restored onto the device they were saved from, which broke resumption on a different device (training would halt because state tensors and parameters ended up on mismatched devices). It turns out torch::load accepts an optional device placement argument, so this PR passes the target device when loading the optimizer to fix the issue. We also use the same argument to load models directly onto the target training device instead of the existing logic that routed everything through the CPU.
This PR additionally adds tests for checkpoint functionality.
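A minimal sketch of the loading pattern this PR uses, assuming a hypothetical `torch::nn::Linear` model and checkpoint file names (`model.pt`, `optimizer.pt`); the relevant piece is the trailing device argument to `torch::load`, which remaps tensors during deserialization so optimizer state lands on the same device as the model parameters:

```cpp
#include <torch/torch.h>

int main() {
  // Target device for resuming training; may differ from the device used to save.
  torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

  // Hypothetical model and optimizer; any serializable module/optimizer works the same way.
  torch::nn::Linear model(10, 1);
  model->to(device);
  torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.01);

  // Load checkpoints directly onto the target device instead of routing through the CPU.
  torch::load(model, "model.pt", device);
  torch::load(optimizer, "optimizer.pt", device);
}
```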