Closed · pfeatherstone closed this 3 years ago
copy.deepcopy fails on cplx modules, whereas something like:

```python
import torch
from copy import deepcopy

n1 = torch.nn.Conv2d(1, 1, 1)
n2 = deepcopy(n1)
```

works fine.
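(A quick way to confirm the copy really is independent; this check is an illustration, not part of the original report:)

```python
import torch
from copy import deepcopy

n1 = torch.nn.Conv2d(1, 1, 1)
n2 = deepcopy(n1)
with torch.no_grad():
    n1.weight.zero_()  # mutate the original in place
print(torch.equal(n1.weight, n2.weight))  # False: the copy keeps its own tensors
```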
I need this to do EMA (an exponential moving average of the model weights):

```python
import math
from copy import deepcopy

import torch
import torch.nn as nn


def is_parallel(model):
    # not in the original snippet: the usual helper that checks for
    # (Distributed)DataParallel wrapping
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


class ModelEMA:
    def __init__(self, model, decay=0.9999, updates=0):
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def to(self, device):
        self.device = device
        self.ema.to(device)

    def update(self, model):
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            ema_msd = self.ema.module.state_dict() if is_parallel(self.ema) else self.ema.state_dict()
            for k, v in ema_msd.items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()
```
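For context, here is a minimal sketch of how the class is meant to be driven from a training loop; the model, optimizer, and data below are placeholders, not from the original report:

```python
# Hypothetical usage of ModelEMA; everything here is a stand-in.
model = torch.nn.Conv2d(1, 1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ema = ModelEMA(model)

for x in [torch.randn(4, 1, 8, 8) for _ in range(3)]:
    loss = model(x).square().mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # EMA weights trail the raw weights, updated in place

# The ramp keeps the effective decay tiny at first, so the EMA tracks the
# model closely early on: decay(1) ~= 0.0005, rising toward 0.9999.
print(ema.decay(ema.updates))
```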
I patched in support for copy.deepcopy and added some tests, which should resolve the problem. The issue was that the cplx.Cplx object did not have the appropriate API outlined in the copy module (the `__copy__` / `__deepcopy__` hooks).
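For reference, `copy.deepcopy` looks for a `__deepcopy__` hook on the object. A minimal sketch of how a tensor-wrapping container can gain deepcopy support (illustrative only; `FakeCplx` is a stand-in, not the actual patch to cplx.Cplx):

```python
import torch
from copy import deepcopy


class FakeCplx:
    """Illustrative stand-in for a complex-tensor container like cplx.Cplx."""

    def __init__(self, real, imag):
        self.real, self.imag = real, imag

    def __deepcopy__(self, memo):
        # copy.deepcopy calls this hook when present; `memo` tracks
        # already-copied objects so shared references stay shared.
        return type(self)(deepcopy(self.real, memo), deepcopy(self.imag, memo))


z1 = FakeCplx(torch.randn(2), torch.randn(2))
z2 = deepcopy(z1)
assert z2.real is not z1.real  # the copy owns its own tensors
```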