bregaldo / pywph

Wavelet Phase Harmonics in PyTorch for Images
BSD 3-Clause "New" or "Revised" License

Autograd broken with new version #3

Closed NiallJeffrey closed 2 years ago

NiallJeffrey commented 2 years ago

In the new version 1.0, autograd is broken. In particular, the example no longer works: https://github.com/bregaldo/pywph/blob/main/examples/compute_wph_grad.py. See the output below.

If I reload the old version (or the current cross branch), everything works fine. This seems like quite a serious bug at the moment.

RuntimeError Traceback (most recent call last)

<ipython-input-2-a76606fa12d8> in <module>
     20 for i in range(nb_chunks):
     21     print(f"{i}/{nb_chunks}")
---> 22     coeffs_chunk = wph_op(data_torch, i)
     23     loss_chunk = (torch.absolute(coeffs_chunk) ** 2).sum()
     24     loss_chunk.backward(retain_graph=True)

/mesopsl3/home/njeffrey/wphenv/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/mesopsl3/home/njeffrey/wphenv/lib/python3.7/site-packages/pywph-1.0-py3.7.egg/pywph/wph_operator.py in forward(self, data, chunk_id, requires_grad, norm, ret_indices, pbc, ret_wph_obj, cross)
   1302         Alias of apply.
   1303         """
-> 1304         return self.apply(data, chunk_id=chunk_id, requires_grad=requires_grad, norm=norm, ret_indices=ret_indices, pbc=pbc, ret_wph_obj=ret_wph_obj, cross=cross)

/mesopsl3/home/njeffrey/wphenv/lib/python3.7/site-packages/pywph-1.0-py3.7.egg/pywph/wph_operator.py in apply(self, data, chunk_id, requires_grad, norm, ret_indices, pbc, ret_wph_obj, cross)
   1276                 coeffs, indices = self._apply_cross_chunk(data[0], data[1], chunk_id, norm, pbc)
   1277             else:
-> 1278                 coeffs, indices = self._apply_chunk(data, chunk_id, norm, pbc)
   1279 
   1280         # We free memory when needed

/mesopsl3/home/njeffrey/wphenv/lib/python3.7/site-packages/pywph-1.0-py3.7.egg/pywph/wph_operator.py in _apply_chunk(self, data, chunk_id, norm, pbc)
   1011 
   1012             # Select the different translations
-> 1013             cov = cov[...,  curr_id_cov_indices - curr_id_cov_indices[0], curr_translation_pos[:, 1], curr_translation_pos[:, 0]] # (..., nb_wph_chunk)
   1014 
   1015             # Normalisation if padding

RuntimeError: index does not support automatic differentiation for outputs with complex dtype.

Eralys commented 2 years ago

Hi Niall, I just ran it, and it went smoothly. I'm using this version of torch: '1.10.0+cu102'

bregaldo commented 2 years ago

Which version of PyTorch are you using? This is most likely a version problem; pywph currently requires torch>=1.8.0.

NiallJeffrey commented 2 years ago

Yep, using 1.8

bregaldo commented 2 years ago

OK, it seems that with the recent update PyWPH now requires torch>=1.9.0. I will update the requirements. Thank you for the feedback!
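
For anyone hitting the same error, here is a minimal sketch of how an installed version string could be checked against the torch>=1.9.0 minimum mentioned above. The helper names `parse_version` and `meets_minimum` are hypothetical and not part of pywph. The sketch uses only the standard library and assumes version strings shaped like the ones in this thread (for example '1.10.0+cu102').

```python
import re

# Minimum version stated in this thread (torch>=1.9.0).
MIN_TORCH = (1, 9, 0)


def parse_version(version: str) -> tuple:
    """Turn a string such as '1.10.0+cu102' into (1, 10, 0).

    Hypothetical helper: the local build suffix after '+' is dropped and
    only the leading digits of each of the first three fields are kept.
    """
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as '1.9' to three fields.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(version: str, minimum: tuple = MIN_TORCH) -> bool:
    """Return True if `version` is at least `minimum` (compared as integers)."""
    return parse_version(version) >= minimum
```

With PyTorch installed, this could be called as `meets_minimum(str(torch.__version__))`. Comparing integer tuples rather than raw strings matters here, because as plain strings '1.10.0' sorts before '1.9.0'.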

NiallJeffrey commented 2 years ago

Ok - I've rerun the install procedure with the new requirements and the gradient example now works fine 👍