Chris-Pedersen / LearnableWavelets

Learnable wavelet neural networks

Drop any kymatio imports #40

Closed Chris-Pedersen closed 2 years ago

Chris-Pedersen commented 2 years ago

Given we want to build statistical isotropy into the model and have more freedom in the number of first- and second-order filters, we will want more flexibility in generating the wavelets and performing the convolutions than Kymatio is intended for. So in this PR we take the key computational parts of Kymatio (i.e. the scattering convolutions), integrate them into our code, and detach from any other Kymatio imports. This first commit is a crude first stab at this, and there's a lot more tidying to be done. A few major points:

  1. There is duplicate code between Kymatio and the parametric scattering network (PSN) code for generating wavelet filters: filter_bank(..) from Kymatio is called, followed by raw_morlet(..) from PSN. We should only need one of these, and in rewriting it we should make sure the wavelet parameters are properly included in the backprop (see the sketch after this list).
  2. There is also duplicate code for things like periodising filters, which appears in both kymat_code and scattering2d and can be removed.
  3. Consider whether we want to keep padding. We don't want reflective padding for CAMELs applications, but it might be worth keeping the code so the model can be applied to other datasets without too much extra work.
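For reference, a minimal sketch of what point 1 is after (hypothetical module and parameter names, not the actual filter_bank(..)/raw_morlet(..) code): rebuilding the Morlet from nn.Parameter tensors inside forward() is what keeps the parameters in the backprop graph.

```python
import math
import torch
import torch.nn as nn

class LearnableMorlet(nn.Module):
    """Sketch: one Morlet filter in Fourier space, built from learnable
    parameters so that gradients flow back to (theta, sigma, xi)."""
    def __init__(self, M, N):
        super().__init__()
        self.M, self.N = M, N
        self.theta = nn.Parameter(torch.tensor(0.0))              # orientation
        self.sigma = nn.Parameter(torch.tensor(0.8))              # envelope width
        self.xi    = nn.Parameter(torch.tensor(3 * math.pi / 4))  # centre frequency

    def forward(self):
        kx = 2 * math.pi * torch.fft.fftfreq(self.M)
        ky = 2 * math.pi * torch.fft.fftfreq(self.N)
        KX, KY = torch.meshgrid(kx, ky, indexing="ij")
        # Rotate the frequency grid by the learnable orientation
        kxr = KX * torch.cos(self.theta) + KY * torch.sin(self.theta)
        kyr = -KX * torch.sin(self.theta) + KY * torch.cos(self.theta)
        # Gabor envelope centred on the learnable frequency xi
        gabor = torch.exp(-self.sigma ** 2 * ((kxr - self.xi) ** 2 + kyr ** 2) / 2)
        # Subtract a scaled Gaussian so the filter is zero-mean (admissible Morlet)
        gauss = torch.exp(-self.sigma ** 2 * (KX ** 2 + KY ** 2) / 2)
        beta = gabor[0, 0] / gauss[0, 0]  # index (0, 0) is k = 0 on fftfreq grids
        return gabor - beta * gauss
```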
Chris-Pedersen commented 2 years ago

I noticed that both wavelets and psi were being used to store the same information (the wavelet filters in Fourier space), so I removed psi from the code to avoid the duplication. I also removed the Kymatio scattering object from the code entirely. The backend code used to perform the convolutions is now imported directly from torch_backend, and the low-pass filter is generated directly using get_phi. We now have no connections to Kymatio, but there is still quite a bit to do:

  1. Much of the wavelet parameter tracking code worked with the psi format of the wavelets rather than with wavelets itself. Given we use wandb to track wavelet parameters now, we could drop this, or update it to work with wavelets instead. NB in some cases the number of filters is hardcoded, so we will also need to generalise this if we want to keep the code.
  2. Need to test the forward passes and backprop to make sure the model still works end to end (a minimal check is sketched after this list).
  3. Decide what to do with padding; just remove it?
  4. There are some duplicate functions from Kymatio and the PSN paper code (i.e. periodize_filter_fft and the code used to generate the wavelets).
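A minimal end-to-end check for point 2 might look like this (reusing the hypothetical LearnableMorlet sketch from the first comment; the real model routes the convolutions through torch_backend):

```python
import torch

filt = LearnableMorlet(M=64, N=64)
x = torch.randn(4, 1, 64, 64)
# Scattering-style convolution: multiply in Fourier space, take the modulus
out = torch.fft.ifft2(torch.fft.fft2(x) * filt()).abs()
out.sum().backward()
# Gradients should reach the wavelet parameters, not just the activations
assert filt.theta.grad is not None and filt.sigma.grad is not None
```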
Chris-Pedersen commented 2 years ago

In the previous code, the wavelet filters were registered as PyTorch buffers named tensorN, where N indexed the Nth wavelet. However, what was actually passed to the convolution loop was the scattering.psi dictionary, so we were storing more duplicate wavelets, and the module buffers themselves weren't being used in the model's forward passes.

I've trimmed this down in the latest commit, where now the wavelet and phi filters are registered as buffers. Unfortunately PyTorch doesn't allow you to register a list of tensors as a single buffer, so I've defined two @property decorators to access the wavelet and phi filter banks as lists, preserving the rest of the code's syntax for now.
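In code, the pattern is roughly this (a sketch with hypothetical names, not the actual class):

```python
import torch
import torch.nn as nn

class ScatteringBase(nn.Module):
    """Sketch: register each filter tensor as its own buffer, then expose
    the banks as lists via read-only properties."""
    def __init__(self, wavelet_filters, phi_filters):
        super().__init__()
        self.n_wavelets = len(wavelet_filters)
        self.n_phi = len(phi_filters)
        for i, w in enumerate(wavelet_filters):
            self.register_buffer(f"wavelet{i}", w)
        for i, p in enumerate(phi_filters):
            self.register_buffer(f"phi{i}", p)

    @property
    def wavelets(self):
        return [getattr(self, f"wavelet{i}") for i in range(self.n_wavelets)]

    @property
    def phi(self):
        return [getattr(self, f"phi{i}") for i in range(self.n_phi)]
```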

So far I have tested that the tensors in sn_base_model.wavelet and sn_base_model.phi are correctly moved from CPU to GPU and vice versa, so the buffer setup is working, and that a forward pass produces no errors. I still need to test the backward passes, and will do that over the coming few days.
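The device check amounts to something like this (assuming the ScatteringBase sketch above and a CUDA machine):

```python
base = ScatteringBase([torch.randn(8, 8) for _ in range(3)], [torch.randn(8, 8)])
base = base.cuda()
# Buffers follow the module across devices, unlike plain attributes
assert all(w.device.type == "cuda" for w in base.wavelets)
assert all(p.device.type == "cuda" for p in base.phi)
```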

A few more things to take care of before merging:

  1. Integrate or remove the rest of the code in kymat_code.py.
  2. We have duplicate periodize_filter_fft functions; verify they do exactly the same thing, then remove the one that involves nested loops and integrate the cleaner one (a vectorised sketch follows this list).
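For reference, the nested loops should reduce to a single reshape-and-sum over blocks of the spectrum. A vectorised sketch; the masking and normalisation conventions would need to be matched against the two originals before swapping it in:

```python
import torch

def periodize_filter_fft(filt: torch.Tensor, res: int) -> torch.Tensor:
    """Fold a Fourier-domain filter onto a grid subsampled by 2**res.

    Summing over the 2**res x 2**res blocks of the spectrum is the
    Fourier-domain counterpart of subsampling by 2**res in real space.
    """
    M, N = filt.shape
    k = 2 ** res
    return filt.reshape(k, M // k, k, N // k).sum(dim=(0, 2))
```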
Chris-Pedersen commented 2 years ago

The way I was trying to use the wavelets as buffers before wasn't working. Torch doesn't appear to allow you to store buffers in a list, so I tried to store them as indexed objects (i.e. model.scatteringBase.waveletN, where N is the index of the wavelet) and then access these through a property decorator. But this was throwing the error:

Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

when training. This happened on the 2nd batch of the first epoch, at the backward() pass after applying zero_grad(). I'm not sure exactly what was causing it, but dropping the property decorator and just storing the wavelets as a normal object fixes it, and training runs fine. However, the wavelets and smoothing filters are then no longer registered as buffers, so there are device conflicts. I have to do something else now but will come back to this tomorrow.

Chris-Pedersen commented 2 years ago

Have also noticed that the scatteringBase.wavelets have requires_grad=True. I don't think this should be the case: only the params_filters are needed for autograd; the wavelet "pixel" values themselves shouldn't be included.
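For what it's worth, a quick way to see what autograd is actually tracking (attribute names as used in this thread). One caveat: if the wavelets are produced by differentiable ops on params_filters, requires_grad=True on them is expected, since they are non-leaf tensors in the graph; detaching them would cut the gradient path back to the parameters.

```python
w = model.scatteringBase.wavelets[0]       # model = the trained model above
p = model.scatteringBase.params_filters[0]
print(p.requires_grad, p.is_leaf)  # True, True  -> leaf parameter, accumulates .grad
print(w.requires_grad, w.is_leaf)  # True, False -> derived tensor, no .grad of its own
```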

Chris-Pedersen commented 2 years ago

Ok, getting there. Miles suggested not bothering with registering the wavelet tensors as buffers, and instead extending the inherited model.to(device) method to move them when the device changes. This seems like the best solution, so I've implemented it and cut out almost all of the remaining Kymatio code.
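A sketch of that pattern (hypothetical class; one way to hook the inherited device movement is to override _apply, which .to()/.cuda()/.cpu() all funnel through):

```python
import torch.nn as nn

class ScatteringBase(nn.Module):
    """Sketch: keep the filters as plain tensor lists and move them
    whenever the module itself is moved."""
    def __init__(self, wavelets, phi):
        super().__init__()
        self.wavelets = wavelets  # plain lists of tensors, not buffers
        self.phi = phi

    def _apply(self, fn):
        # Applying fn to our plain attributes keeps the filter tensors on
        # the same device/dtype as the rest of the module
        super()._apply(fn)
        self.wavelets = [fn(w) for w in self.wavelets]
        self.phi = [fn(p) for p in self.phi]
        return self
```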

Our smoothing filters are currently generated using the Kymatio gabor_2d code, which is now its only purpose in the codebase. So the next steps are:

  1. Remove this and generate the Gaussian filter from an appropriate standalone function (a possible sketch follows this list)
  2. Tidy up the comments and docstrings
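Point 1 could be as simple as building the Gaussian directly in Fourier space, e.g. (a sketch; the width convention would need matching against the current gabor_2d output):

```python
import math
import torch

def gaussian_filter_fft(M, N, sigma):
    """Isotropic Gaussian low-pass filter, sampled directly in Fourier space.

    A real-space Gaussian of width sigma has Fourier transform
    exp(-|k|^2 sigma^2 / 2), with unit response at k = 0 (pure smoothing).
    """
    kx = 2 * math.pi * torch.fft.fftfreq(M)
    ky = 2 * math.pi * torch.fft.fftfreq(N)
    KX, KY = torch.meshgrid(kx, ky, indexing="ij")
    return torch.exp(-(KX ** 2 + KY ** 2) * sigma ** 2 / 2)
```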

Then I think we are good to merge this branch.

Chris-Pedersen commented 2 years ago

Have tested that this reproduces the old results. Still a bit more tidying up to do, but this is probably a good place for a merge.