Added `.to(p.data.device)` so that `p.data` can be `lerp_`-ed with `state['z']` even when the two tensors live on different devices.
This avoids device-mismatch errors such as `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!`
Reason for PR: in evaluation-only mode, I noticed that `state['z']` was loaded on the CPU while `p.data` was on CUDA.
Let me know if this fix is relevant and I'll apply it to the other optimizers' `eval` functions as well.
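A minimal sketch of the pattern, assuming a PyTorch optimizer whose per-parameter state holds a `z` buffer that `eval()` interpolates into `p.data`. The function `eval_swap` and its `weight` argument are hypothetical and are not the optimizer's actual code; the relevant part is the `.to(p.data.device)` call on `z`.

```python
import torch


def eval_swap(params, state, weight):
    """Hypothetical eval-mode step: lerp each parameter toward its 'z' buffer.

    Assumes `state` maps each parameter to a dict that may hold a 'z' tensor,
    possibly on a different device (e.g. CPU after loading a checkpoint).
    """
    for p in params:
        z = state[p].get('z')
        if z is None:
            continue
        # Move z onto the parameter's device before the in-place lerp_.
        # Without this, a CPU z and a CUDA p.data raise a device-mismatch RuntimeError.
        # .to() returns z itself when it is already on the right device.
        p.data.lerp_(end=z.to(p.data.device), weight=weight)


# Usage: z is deliberately left on the CPU while p uses CUDA if available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
p = torch.nn.Parameter(torch.zeros(3, device=device))
state = {p: {'z': torch.ones(3)}}
eval_swap([p], state, weight=0.5)
print(p.data)  # every element is 0.5, on `device`
```

With `z` left on the CPU, each call copies it to the parameter's device. Moving the state once after loading would avoid that, but the per-call `.to()` keeps the fix local to the line that fails.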