facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch
Apache License 2.0
1.91k stars, 65 forks

Update tensor device on train and eval function #41

Closed: zhulinchng closed this 3 months ago

zhulinchng commented 3 months ago

Added .to(p.data.device) to state['z'] so that p.data can be updated in place with lerp_ toward state['z'].

This avoids device-mismatch errors such as: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
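A minimal sketch of the idea, assuming an eval-style switch that moves each parameter from the training point back to the averaged point by lerping toward state['z'] with weight 1 - 1/beta1. That weight and the helper name switch_to_eval are assumptions for illustration, not code copied from the repository. The key line is the .to(p.data.device) call on state['z'] before the in-place lerp_.

```python
import torch


def switch_to_eval(params, state, beta1):
    """Hypothetical sketch of an eval() switch for a schedule-free optimizer.

    Assumes the training point is y = beta1 * x + (1 - beta1) * z and moves
    each parameter back to x in place.
    """
    for p in params:
        param_state = state.get(p)
        if param_state is None or "z" not in param_state:
            continue
        # state['z'] may live on a different device than p (e.g. CPU while p is
        # on CUDA); move it to p's device so the in-place lerp_ does not fail.
        z = param_state["z"].to(p.data.device)
        # x = y / beta1 - (1 - beta1) / beta1 * z, expressed as a lerp toward z.
        p.data.lerp_(end=z, weight=1 - 1 / beta1)


# Usage on CPU: the .to() call is a no-op when the devices already match.
p = torch.nn.Parameter(torch.tensor([2.0, 4.0]))
state = {p: {"z": torch.tensor([1.0, 1.0])}}
switch_to_eval([p], state, beta1=0.9)
print(p.data)
```

Moving z to the parameter's device (rather than the parameter to z's device) keeps the model where the user placed it; only the temporary copy of the optimizer state is relocated.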

Reason for PR: in evaluation-only mode, I noticed that state['z'] was loaded on the CPU while p.data was on CUDA.

Let me know if this fix is relevant, and I'll apply it to the other optimizers' eval functions as well.
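One way to apply the same fix across several optimizers' train and eval switches is a single shared helper. The sketch below is hypothetical (lerp_param_with_z, to_train_point, and to_eval_point are illustrative names) and assumes the schedule-free convention y = beta1 * x + (1 - beta1) * z, so that the train switch lerps with weight 1 - beta1 and the eval switch with weight 1 - 1/beta1.

```python
import torch


def lerp_param_with_z(p, z, weight):
    """Hypothetical helper: p.data <- (1 - weight) * p.data + weight * z, in place.

    z is first moved to p's device so CPU-resident optimizer state can be
    combined with a CUDA parameter (and vice versa).
    """
    p.data.lerp_(end=z.to(p.data.device), weight=weight)


def to_train_point(p, z, beta1):
    # x -> y = beta1 * x + (1 - beta1) * z
    lerp_param_with_z(p, z, 1 - beta1)


def to_eval_point(p, z, beta1):
    # y -> x, the inverse of to_train_point
    lerp_param_with_z(p, z, 1 - 1 / beta1)


# Round trip: switching to train mode and back should recover the original x.
x0 = torch.tensor([0.5, -1.0, 3.0])
p = torch.nn.Parameter(x0.clone())
z = torch.tensor([1.0, 2.0, -2.0])
to_train_point(p, z, beta1=0.9)
y = p.data.clone()
to_eval_point(p, z, beta1=0.9)
print(torch.allclose(p.data, x0, atol=1e-6))
```

Keeping the device move inside one helper means each optimizer's train() and eval() only needs to call it, instead of repeating .to(p.data.device) at every lerp_ site.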

adefazio commented 3 months ago

Seems like a reasonable fix, thanks! Let me know once it's applied to all the eval functions.

zhulinchng commented 3 months ago

Closing to reopen, as the current PR has been stuck at "Processing updates" for more than 5 hours.