Hi @adefazio, I have applied the fix for the `eval` and `train` functions in both `sgd_schedulefree.py` and `adamw_schedulefree.py`.
This ensures the optimizer state is allocated on the same `torch.device` as the parameters, avoiding device mismatch errors, especially when the optimizer is resumed for training or used in evaluation-only mode.
Added `.to(p.data.device)` to ensure `p.data` can be `lerp_`ed with `state['z']`. This avoids errors from tensors living on different devices, such as:

`RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!`
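For reference, here is a minimal sketch of what the patched switching logic looks like in the `eval` path of `sgd_schedulefree.py`; the loop structure and the `momentum` group key are paraphrased from the file rather than copied from the exact diff:

```python
@torch.no_grad()
def eval(self):
    for group in self.param_groups:
        momentum = group['momentum']
        if group['train_mode']:
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    # Move z onto p.data's device before interpolating,
                    # so lerp_ never mixes cuda and cpu tensors.
                    p.data.lerp_(end=state['z'].to(p.data.device),
                                 weight=1 - 1 / momentum)
            group['train_mode'] = False
```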
Normally, on a first training run, the device mismatch doesn't surface because the `step` function initializes the state with `self.state[p]['z'] = torch.clone(p.data)`, which places `z` on the same device as `p.data`. However, if `step` is never called, such as when continuing training from a checkpoint or running in evaluation-only mode, a device mismatch error can occur.
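To make the failure mode concrete, here is a hypothetical reproduction (the scenario itself is my construction for illustration; it assumes the `schedulefree` package import and a machine with CUDA available). The state tensor `z` is created on CPU by the first `step`, and moving the model afterwards relocates the parameters but not the optimizer state:

```python
import torch
from schedulefree import SGDScheduleFree  # assumed import path

model = torch.nn.Linear(10, 1)
opt = SGDScheduleFree(model.parameters(), lr=0.1)
opt.train()

# One step on CPU creates state['z'] via torch.clone(p.data), on CPU.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()

# Moving the model now relocates the parameters but not the optimizer
# state, so p.data is on cuda:0 while state['z'] is still on cpu.
model.cuda()

# Without the .to(p.data.device) fix, this lerp_ would raise:
# RuntimeError: Expected all tensors to be on the same device, ...
opt.eval()
```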