airlab-unibas / airlab

Image registration laboratory for 2D and 3D image data
Apache License 2.0

Can't use EarlyStopping in registration #13

Open bsugerman opened 5 years ago

bsugerman commented 5 years ago

Hi guys, at least in Python 3.7 and torch 1.1.0, enabling EarlyStopping causes registration.start() of a PairwiseRegistration to fail with the error:

~\AppData\Local\Continuum\anaconda3\lib\site-packages\airlab-0.2.1-py3.7.egg\airlab\registration\registration.py in start(self, EarlyStopping, StopPatience)
    143                     n = 0
    144                     self.loss=loss
--> 145                     best=deepcopy(self._transformation)
    146                 else:
    147                     n += 1

~\AppData\Local\Continuum\anaconda3\lib\copy.py in deepcopy(x, memo, _nil)
    178                     y = x
    179                 else:
--> 180                     y = _reconstruct(x, memo, *rv)
    181 
    182     # If is its own copy, don't memoize.

~\AppData\Local\Continuum\anaconda3\lib\copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    278     if state is not None:
    279         if deep:
--> 280             state = deepcopy(state, memo)
    281         if hasattr(y, '__setstate__'):
    282             y.__setstate__(state)

~\AppData\Local\Continuum\anaconda3\lib\copy.py in deepcopy(x, memo, _nil)
    148     copier = _deepcopy_dispatch.get(cls)
    149     if copier:
--> 150         y = copier(x, memo)
    151     else:
    152         try:

~\AppData\Local\Continuum\anaconda3\lib\copy.py in _deepcopy_dict(x, memo, deepcopy)
    238     memo[id(x)] = y
    239     for key, value in x.items():
--> 240         y[deepcopy(key, memo)] = deepcopy(value, memo)
    241     return y
    242 d[dict] = _deepcopy_dict

~\AppData\Local\Continuum\anaconda3\lib\copy.py in deepcopy(x, memo, _nil)
    159             copier = getattr(x, "__deepcopy__", None)
    160             if copier:
--> 161                 y = copier(memo)
    162             else:
    163                 reductor = dispatch_table.get(cls)

~\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\tensor.py in __deepcopy__(self, memo)
     21     def __deepcopy__(self, memo):
     22         if not self.is_leaf:
---> 23             raise RuntimeError("Only Tensors created explicitly by the user "
     24                                "(graph leaves) support the deepcopy protocol at the moment")
     25         if id(self) in memo:

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
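The error can be reproduced outside airlab. The following is a minimal sketch, not airlab code: it only assumes that some tensor reachable from the transformation is the result of an autograd operation (a non-leaf). The variable names are illustrative.

```python
from copy import deepcopy

import torch

# A tensor created directly by the user is a graph leaf.
leaf = torch.zeros(3, requires_grad=True)
# Any tensor computed from it with autograd tracking is a non-leaf.
non_leaf = leaf * 2

# deepcopy works on leaves ...
copied = deepcopy(leaf)

# ... but raises the RuntimeError from the traceback on non-leaves.
try:
    deepcopy(non_leaf)
    raised = False
except RuntimeError:
    raised = True

# Detaching first gives an independent leaf copy that holds the same values.
detached_copy = non_leaf.detach().clone()
```

So deepcopy handles tensors in general; it is only the non-leaf case that is rejected.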

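This makes sense, since deepcopy only supports leaf tensors and the transformation evidently holds a non-leaf one (a tensor produced by an autograd operation). I think a solution is in the PairwiseRegistration class, by changing the line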

best=deepcopy(self._transformation)

to just extract the displacement tensor, and then, once early stopping is reached, replacing the displacement tensor in self._transformation with best. The problem is that I can extract the tensor from the PairwiseRegistration class (e.g. with detach()), but I can't for the life of me figure out how to replace it. I'm hoping you might know the trick?
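One possible way to do the replacement is sketched below. It assumes the transformation is a torch.nn.Module whose displacement is a registered parameter; this has not been checked against airlab's classes. ToyTransformation, trans_parameters, snapshot and restore are hypothetical names used only for illustration. The idea is to store a detached copy of state_dict() instead of deep-copying the whole object, and to write it back with load_state_dict(), which copies the values into the existing parameters in place.

```python
import torch


class ToyTransformation(torch.nn.Module):
    """Hypothetical stand-in for a transformation object (not airlab code)."""

    def __init__(self):
        super().__init__()
        self.trans_parameters = torch.nn.Parameter(torch.zeros(2, 4, 4))
        # A cached non-leaf attribute like this is what makes deepcopy() fail.
        self.cached = self.trans_parameters * 1.0


def snapshot(module):
    """Return detached copies of all parameters and buffers."""
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


def restore(module, state):
    """Copy a snapshot back into the module's existing parameters in place."""
    module.load_state_dict(state)


transformation = ToyTransformation()
optimizer = torch.optim.SGD(transformation.parameters(), lr=0.1)
target = torch.ones(2, 4, 4)

best_loss, best_state = float("inf"), None
patience, bad_steps = 3, 0
for step in range(50):
    optimizer.zero_grad()
    loss = ((transformation.trans_parameters - target) ** 2).mean()
    if loss.item() < best_loss:
        best_loss, bad_steps = loss.item(), 0
        # Save the parameters that produced this loss, before they are updated.
        best_state = snapshot(transformation)
    else:
        bad_steps += 1
        if bad_steps > patience:
            break
    loss.backward()
    optimizer.step()

# Put the best parameters back; the optimizer still refers to the same tensors.
restore(transformation, best_state)
```

Because load_state_dict() writes into the existing Parameter objects, the optimizer and any other references to them stay valid. An equivalent manual route would be to call copy_() on the parameter inside a torch.no_grad() block.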

linson7017 commented 2 years ago

The problem is still there. Can anyone resolve it?

mysakbm commented 2 years ago

I can confirm that the problem is still there.