facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch
Apache License 2.0

Update tensor device on train and eval function #42

Closed: zhulinchng closed this 3 months ago

zhulinchng commented 3 months ago

Reopened from #41 due to an issue with GitHub processing.

Added .to(p.data.device) so that p.data can be lerp_'d with state['z'].

This avoids errors caused by tensors living on different devices, such as: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

This ensures the optimizer state is moved to the matching torch.device to avoid device mismatch errors, especially when the optimizer is resumed for training or used in evaluation-only mode. A sketch of the change is shown below.
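For reference, a minimal sketch of the pattern (assuming the current eval() structure in sgd_schedulefree.py, where p.data is interpolated toward state['z']; the train() path and the AdamW variant mirror this):

```python
def eval(self):
    for group in self.param_groups:
        momentum = group['momentum']
        if group['train_mode']:
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    # Move z onto the parameter's device before interpolating,
                    # so a cuda:0 p.data and a cpu z no longer raise a
                    # RuntimeError about mismatched devices.
                    p.data.lerp_(end=state['z'].to(p.data.device),
                                 weight=1 - 1 / momentum)
            group['train_mode'] = False
```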

Normally, on the first training run, the device mismatch wouldn't occur, because the step function initializes the state with torch.clone at self.state[p]['z'] = torch.clone(p.data), which places z on the same device as the parameter. However, if the step function isn't called first, such as when continuing training from a checkpoint or running in evaluation-only mode, a device mismatch error can occur, as in the sketch below.
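A hedged repro sketch of that scenario (assuming the pip-installable schedulefree package and its AdamWScheduleFree class): one step on CPU clones z onto the CPU, then moving the model to GPU and calling eval() without another step mixes devices.

```python
import torch
import schedulefree  # assumes the published schedulefree package

model = torch.nn.Linear(10, 2)
opt = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

# First step on CPU: state['z'] = torch.clone(p.data), so z lives on the CPU.
opt.train()
model(torch.randn(4, 10)).sum().backward()
opt.step()

# Move the model to GPU without taking another step, e.g. when switching to
# evaluation-only mode. p.data is now on cuda:0 while state['z'] is still on cpu.
model.cuda()

# Without the .to(p.data.device) fix, the lerp_ between p.data (cuda:0) and
# state['z'] (cpu) raises the RuntimeError quoted above.
opt.eval()
```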

zhulinchng commented 3 months ago

Hi @adefazio, I have applied the fix to the eval and train functions in both sgd_schedulefree.py and adamw_schedulefree.py.

This ensures the optimizer state is moved to the matching torch.device to avoid device mismatch errors, especially when the optimizer is resumed for training or used in evaluation-only mode.