Closed · Chris-Pedersen closed this 2 years ago
I noticed that both `wavelets` and `psi` were being used to store the same information (the wavelet filters in Fourier space), so I removed `psi` from the code to avoid the duplication. Also removed the Kymatio scattering object from the code entirely. The backend code used to perform the convolutions is now imported directly from `torch_backend`, and the low pass filter is generated directly using `get_phi`. We now have no connections to Kymatio, but quite a bit more to do still:
- …uses the `psi` format of the wavelets instead of `wavelets`. Given we use wandb to track wavelet parameters now, we could drop this, or update it to work with `wavelets` instead. NB in some cases the number of filters is hardcoded, so we will also need to generalise this if we want to keep this code.
- …`periodize_filter_fft` and the code used to generate the wavelets.

In the previous code, the wavelet filters were registered as pytorch buffers, named `tensorN` where `N` was the index of the wavelet. However the wavelets that were passed to the convolution loop came from the `scattering.psi` dictionary, so we had yet more duplicate wavelets being stored, and the module buffers themselves weren't being used in the forward passes of the model.
I've trimmed this down in the latest commit: the wavelet and phi filters are now registered as buffers. Unfortunately pytorch doesn't allow you to store buffers as a list, so I've defined two `@property` decorators to access the wavelet and phi filter banks as lists, preserving the rest of the syntax of the code for now.
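As a rough sketch of that setup (class and attribute names here are illustrative, not the actual PSN code), registering each filter individually and rebuilding the list view with a property looks like:

```python
import torch
import torch.nn as nn

class ScatteringBase(nn.Module):
    """Minimal sketch of the buffer setup described above; names are
    illustrative, not the actual PSN module."""

    def __init__(self, wavelets, phi):
        super().__init__()
        self.num_filters = len(wavelets)
        for n, filt in enumerate(wavelets):
            # buffers move with .to()/.cuda() and appear in state_dict,
            # but are not trainable parameters
            self.register_buffer(f"wavelet_{n}", filt)
        self.register_buffer("phi_0", phi)

    @property
    def wavelets(self):
        # rebuild the list view so existing call sites keep working
        return [getattr(self, f"wavelet_{n}") for n in range(self.num_filters)]

    @property
    def phi(self):
        return [self.phi_0]
```

Because each filter is a registered buffer, `model.to("cuda")` moves the whole bank, while the properties keep the old list-style access intact.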
I have tested so far that the tensors in `sn_base_model.wavelet` and `sn_base_model.phi` are correctly being moved from CPU to GPU and vice versa, so the buffer setup is working, and have also tested that a forward pass produces no errors. Still need to test the backward passes though, and will do that over the coming few days.
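For context, the convolutions the backend performs are elementwise products in Fourier space; a generic sketch of the core operation (not the actual `torch_backend` code, which is more involved):

```python
import torch

def fourier_conv(x, filt_hat):
    """Convolve an image with a filter stored in Fourier space.
    Generic illustration of a scattering-style convolution; the real
    backend code is more involved (batching, subsampling, etc.)."""
    x_hat = torch.fft.fft2(x)
    return torch.fft.ifft2(x_hat * filt_hat)
```

With an all-ones filter this is the identity, which makes a handy sanity check for device and dtype handling.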
A few more things to take care of before merging:

- …`kymat_code.py`
- …`periodize_filter_fft`, verify they do exactly the same thing and then remove the one that involves nested loops, and integrate the cleaner one.

The way I was trying to use the wavelets as buffers before wasn't working. Torch doesn't appear to allow you to store buffers in a list, so I tried to store them as indexed objects (i.e. `model.scatteringBase.waveletN`, where `N` is the index of the wavelet) and then access these using a property decorator. But this was throwing the error:
```
Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```
when training. This happened on the 2nd batch of the first epoch, at the `backward()` pass after applying `zero_grad()`. I'm not sure exactly what was causing it, but dropping the property decorator and just storing the wavelets as a normal object fixes this, and the training runs fine. However the issue now is that the wavelets and smoothing filters aren't registered as buffers, and so there are device conflicts. I have to do something else now but will come back to this tomorrow.
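For the record, that error usually means a tensor built once, outside the training loop, is being backpropagated through on every batch, so the second `backward()` walks an already-freed graph. A generic minimal reproduction (not the PSN code):

```python
import torch

# A graph node created once, outside the loop, is backpropagated
# through on every batch; the second backward() hits freed buffers.
w = torch.randn(3, requires_grad=True)   # stand-in for filter parameters
filt = torch.exp(w)                      # built once; exp saves its output

for step in range(2):
    loss = (filt ** 2).sum()
    try:
        loss.backward()                  # frees filt's saved tensors
        print(f"step {step}: ok")
    except RuntimeError as err:
        print(f"step {step}: {err}")
```

Regenerating the tensor inside the loop (or per forward pass) gives each batch a fresh graph and avoids the error.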
Have also noticed that the `scatteringBase.wavelets` have `requires_grad=True`. I don't think this should be the case: only the `params_filters` are needed for autograd; the wavelet "pixel" values themselves shouldn't be included.
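One caveat worth noting (generic autograd behaviour, not specific to PSN): if the wavelet tensors are produced differentiably from `params_filters`, they are non-leaf tensors and will report `requires_grad=True` by construction, even though gradients only accumulate on the parameters. A small illustration with a made-up parametrised "wavelet":

```python
import torch

# Illustrative only: a "wavelet" built from learnable parameters.
params = torch.tensor([2.0, 3.0], requires_grad=True)  # stand-in for params_filters
grid = torch.linspace(-1.0, 1.0, 16)
wavelet = torch.exp(-(grid / params[0]) ** 2) * torch.cos(params[1] * grid)

print(wavelet.requires_grad)  # True: it sits inside the autograd graph
print(wavelet.is_leaf)        # False: no .grad is accumulated on it

# Gradients flow through the wavelet to the parameters:
wavelet.sum().backward()
print(params.grad is not None)  # True

# To freeze the pixel values instead, detach the tensor:
frozen = wavelet.detach()
print(frozen.requires_grad)   # False
```

So `requires_grad=True` on the wavelets is expected if they are derived from the learnable parameters; `detach()` is only wanted if the pixel values should be frozen.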
OK, getting there. Miles suggested not bothering with buffering the wavelet tensors, and instead extending the inherited `model.to(device)` method to move them when the device changes. This seems like the best solution, so I've implemented that, and cut out almost all of the remaining kymatio code.
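A sketch of that extension (names are illustrative, assuming the filters live in plain list/tensor attributes rather than as buffers):

```python
import torch
import torch.nn as nn

class ScatteringBase(nn.Module):
    """Sketch of extending the inherited .to() so that plain-tensor
    filter banks follow the module across devices/dtypes; names are
    illustrative, not the actual PSN class."""

    def __init__(self, wavelets, phi):
        super().__init__()
        self.wavelets = wavelets   # plain list of tensors, not buffers
        self.phi = phi

    def to(self, *args, **kwargs):
        # move parameters/buffers/submodules via the inherited method
        super().to(*args, **kwargs)
        # then move the unregistered filter tensors by hand
        self.wavelets = [w.to(*args, **kwargs) for w in self.wavelets]
        self.phi = self.phi.to(*args, **kwargs)
        return self
```

Since `Tensor.to` accepts the same device/dtype arguments as `Module.to`, the override just forwards them, so `model.to("cuda")` or `model.to(torch.float64)` keeps the filters in sync with the module.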
Our smoothing filters are currently generated using the kymatio `gabor_2d` code, and this is now its only purpose in the code. So next steps are:
Then I think we are good to merge this branch.
Have tested that this reproduces the old results. Still a bit more tidying up to go, but probably a good place for a merge.
Given we want to build statistical isotropy into the model and have more freedom in the numbers of first- and second-order filters, we will want more flexibility in generating the wavelets and performing the convolutions than Kymatio is intended for. So in this PR we want to take the key computational parts of Kymatio (i.e. the scattering convolutions) and integrate them into our code, and detach from any other Kymatio imports. This first commit is a crude first stab at this, but there's a lot more tidying to be done. A few major points:

- `filter_bank(..)` from Kymatio is called, followed by `raw_morlet(..)` from PSN. We should only need one of these, but in rewriting this, make sure that the parameters are properly included in the backprop.
- …`kymat_code` and `scattering2d` that can be removed.
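On the first point, the filter generation can be written directly in torch so the shape parameters stay in the graph; a rough differentiable sketch (loosely modelled on the `gabor_2d` signature, not the actual Kymatio or PSN implementation):

```python
import torch

def gabor_2d(M, N, sigma, theta, xi, slant=1.0):
    """Differentiable 2D Gabor filter in real space. sigma/theta/xi are
    expected to be tensors and may have requires_grad=True, so they are
    included in backprop. Rough sketch; argument names follow kymatio's
    gabor_2d only loosely."""
    y, x = torch.meshgrid(
        torch.arange(M, dtype=torch.float32) - M // 2,
        torch.arange(N, dtype=torch.float32) - N // 2,
        indexing="ij",
    )
    # rotate the grid by theta
    x_r = x * torch.cos(theta) + y * torch.sin(theta)
    y_r = -x * torch.sin(theta) + y * torch.cos(theta)
    # Gaussian envelope modulated by a plane wave of frequency xi
    envelope = torch.exp(-(x_r ** 2 + (slant * y_r) ** 2) / (2 * sigma ** 2))
    return envelope * torch.exp(1j * xi * x_r)
```

Written this way, a loss on the scattering coefficients backpropagates through the filter pixels to `sigma`, `theta` and `xi`, which is the requirement flagged above.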