ivannz / cplxmodule

Complex-valued neural networks for pytorch and Variational Dropout for real and complex layers.
MIT License

deepcopy doesn't work with cplxmodule modules #18

Closed (pfeatherstone closed this issue 3 years ago)

pfeatherstone commented 3 years ago
from copy import deepcopy
import cplxmodule as cplx

n1 = cplx.nn.CplxConv2d(1,1,1)
n2 = deepcopy(n1)

This fails with

n2 = deepcopy(n1)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 274, in _reconstruct
    y = func(*args)
  File "/usr/lib/python3.6/copyreg.py", line 88, in __newobj__
    return cls.__new__(cls, *args)
TypeError: __new__() missing 1 required positional argument: 'real'

Process finished with exit code 1

pfeatherstone commented 3 years ago

Whereas something like:

import torch
from copy import deepcopy

n1 = torch.nn.Conv2d(1,1,1)
n2 = deepcopy(n1)

works fine

pfeatherstone commented 3 years ago

I need this to keep an exponential moving average (EMA) of the model weights:

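import math
from copy import deepcopy

import torch


def is_parallel(model):
    # Not shown in the original post; assumed to report whether the model
    # is wrapped in DataParallel or DistributedDataParallel.
    return isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))


class ModelEMA: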
    def __init__(self, model, decay=0.9999, updates=0):
        self.ema        = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        self.updates    = updates  # number of EMA updates
        self.decay      = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def to(self, device):
        self.device = device
        self.ema.to(device)

    def update(self, model):
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            ema_msd = self.ema.module.state_dict() if is_parallel(self.ema) else self.ema.state_dict()
            for k, v in ema_msd.items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()
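
To look at the decay ramp from the snippet above in isolation, here is a small sketch that uses only the standard library. The function name ema_decay and the parameter name tau are made up for this example; the formula and the constants 0.9999 and 2000 are taken from the code above.

```python
import math


def ema_decay(updates, decay=0.9999, tau=2000):
    """Effective EMA decay after `updates` steps (same formula as self.decay above)."""
    return decay * (1 - math.exp(-updates / tau))


# The effective decay starts at zero, so early updates mostly track the live
# weights, and it approaches `decay` as the number of updates grows.
for step in (0, 100, 2000, 20000):
    print(step, round(ema_decay(step), 4))
```

With a small effective decay the update `v = d * v + (1 - d) * msd[k]` is dominated by the current weights, which is why the ramp helps in early epochs.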
ivannz commented 3 years ago

I patched in support for copy.deepcopy and added some tests, which should resolve the problem. The issue was that the cplx.Cplx object did not implement the copying API described in the copy module documentation.
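
A minimal sketch of the failure mode and of the kind of fix described above, using a toy class rather than the real cplx.Cplx. The names Pair and CopyablePair are hypothetical, and the actual patch in cplxmodule may differ. When a class requires positional arguments in __new__ and defines no copy hooks, deepcopy falls back to __reduce_ex__, which ends up calling cls.__new__(cls) with no further arguments and raises the TypeError seen in the traceback. Defining __copy__ and __deepcopy__ avoids that path.

```python
from copy import copy, deepcopy


class Pair:
    """Toy stand-in for a Cplx-like container (NOT the cplxmodule class)."""

    def __new__(cls, real, imag=None):
        # `real` is required here, so a bare cls.__new__(cls) fails.
        self = super().__new__(cls)
        self.real = real
        self.imag = imag if imag is not None else [0.0] * len(real)
        return self


class CopyablePair(Pair):
    """Same container, with the hooks that the copy module looks for."""

    def __copy__(self):
        # Shallow copy: new wrapper, same underlying data objects.
        return type(self)(self.real, self.imag)

    def __deepcopy__(self, memo):
        # Deep copy: new wrapper around copies of the underlying data.
        return type(self)(deepcopy(self.real, memo), deepcopy(self.imag, memo))


try:
    deepcopy(Pair([1.0], [2.0]))
except TypeError as exc:
    print("without hooks:", exc)

original = CopyablePair([1.0], [2.0])
shallow, deep = copy(original), deepcopy(original)
print(shallow.real is original.real, deep.real is original.real)
```

An alternative that the copy module also supports is to define __getnewargs__ so that the default reconstruction passes the required arguments to __new__; the explicit hooks are used here only because they make the copying behaviour easy to see.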