teamtomo / membrain-seg

membrane segmentation in 3D for cryo-ET

Preprocessing #8

Closed LorenzLamm closed 1 year ago

LorenzLamm commented 1 year ago

This branch implements the following:

  1. Pixel size matching: Fourier cropping / Fourier extension to achieve the specified tomogram pixel size. For both cropping and extension, an ellipsoid mask with cosine decay to zero is applied to avoid artifacts.
  2. Spectral matching: This was adapted from the implementation in DeePiCt (https://github.com/ZauggGroup/DeePiCt/tree/main/spectrum_filter). I adjusted some details to avoid artifacts and division by values close to zero.
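
For readers who have not seen Fourier cropping before, here is a rough NumPy sketch of item 1. It is not the code in this branch: the function names, the `decay_width` parameter and the mean-preserving rescale are assumptions made for illustration, and only the cropping direction (coarser output pixels) is handled.

```python
import numpy as np


def cosine_ellipsoid_mask(shape, decay_width=0.1):
    """Ellipsoid mask: 1 in the centre, cosine decay to 0 at normalised radius 1.

    `decay_width` (fraction of the radius used for the decay) is a
    hypothetical parameter chosen for this sketch.
    """
    # Coordinates are centred on index n // 2, where fftshift places DC.
    axes = [(np.arange(n) - n // 2) / (n / 2) for n in shape]
    zz, yy, xx = np.meshgrid(*axes, indexing="ij")
    radius = np.sqrt(zz**2 + yy**2 + xx**2)
    # ramp goes 0 -> 1 across the band [1 - decay_width, 1]
    ramp = np.clip((radius - (1.0 - decay_width)) / decay_width, 0.0, 1.0)
    return 0.5 * (1.0 + np.cos(np.pi * ramp))


def fourier_crop(volume, pixel_size_in, pixel_size_out):
    """Downsample `volume` by cropping its centred Fourier transform."""
    in_shape = np.array(volume.shape)
    out_shape = np.round(in_shape * pixel_size_in / pixel_size_out).astype(int)
    if np.any(out_shape > in_shape):
        raise ValueError("this sketch only handles cropping (coarser output pixels)")
    spectrum = np.fft.fftshift(np.fft.fftn(volume))
    # Keep the central block so that DC stays at index n // 2.
    start = in_shape // 2 - out_shape // 2
    slices = tuple(slice(s, s + n) for s, n in zip(start, out_shape))
    cropped = spectrum[slices] * cosine_ellipsoid_mask(tuple(out_shape))
    resampled = np.fft.ifftn(np.fft.ifftshift(cropped)).real
    # Rescale so the mean intensity of the input is preserved.
    return resampled * (np.prod(out_shape) / np.prod(in_shape))
```

Going from 1 Å to 2 Å pixels, for example, `fourier_crop(volume, 1.0, 2.0)` halves each dimension. Fourier extension would pad the masked spectrum with zeros instead of cropping it.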

Let me know what you think :)

alisterburt commented 1 year ago

to be clear - the path forward here is to merge and iterate! 🙂

LorenzLamm commented 1 year ago

hey @LorenzLamm

First - this is really awesome, I'm glad we have this functionality in here. The PR is too big for comments on everything individually to not be a slow, frustrating experience to iterate on - instead I suggest we merge this and I will provide some overarching comments that can guide improving things as we move forward

  • x/y/z ordering of images is a little strange, could you explain what's going on there? :)
  • matching utils should probably be a subpackage of the preprocessing package, if writing this myself I would probably use the following organisation:
  • preprocessing
    • pixel_size_matching
      • _cli.py
      • match_pixel_size.py
    • amplitude_spectrum_matching
      • _cli.py
      • match_amplitude_spectrum.py

Rather than .py scripts which have to be located/added to path/executed, you can install scripts during package installation automatically with the project.scripts block in the pyproject.toml file - here is a PR to a different project where I template/discuss this for someone else bbarad/ETSegTools#1

Discussed in the PR above, I really like using Typer for turning a simple type annotated function into a script which can be executed from the command line - worth trying!
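
A minimal sketch of what that could look like, assuming Typer is installed. The function name, its argument and option, and the module path in the comment are all hypothetical placeholders, not names from this repository.

```python
import typer

app = typer.Typer()

# Hypothetical pyproject.toml entry that would install this as a command:
#
#   [project.scripts]
#   match_pixel_size = "membrain_seg.preprocessing.pixel_size_matching._cli:app"
#
# The package/module path above is an assumption for illustration only.


@app.command()
def match_pixel_size(
    tomogram_path: str = typer.Argument(..., help="Input tomogram file."),
    pixel_size_out: float = typer.Option(10.0, help="Target pixel size in angstrom."),
):
    """Rescale a tomogram to the target pixel size (placeholder body)."""
    # A real implementation would load the tomogram and resample it here.
    typer.echo(f"Would rescale {tomogram_path} to {pixel_size_out} A/pixel")


if __name__ == "__main__":
    app()
```

Typer derives `--pixel-size-out` and the `--help` text from the type annotations, so the function stays importable and testable as ordinary Python.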

In general it would be great to have more explicit function names e.g. radial_average_3d rather than rad_avg
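
To make the naming suggestion concrete, here is one way a `radial_average_3d` could be written with NumPy. This is a sketch under assumptions (integer-radius shells around the central voxel), not the `rad_avg` implementation in the PR.

```python
import numpy as np


def radial_average_3d(volume_3d):
    """Mean of a 3D array over spherical shells around its central voxel.

    Returns a 1D profile indexed by the integer shell radius in voxels.
    """
    shape = volume_3d.shape
    center = [n // 2 for n in shape]
    # Open grids of signed distances from the centre along each axis.
    grids = np.ogrid[tuple(slice(-c, n - c) for c, n in zip(center, shape))]
    radius = np.sqrt(sum(g.astype(float) ** 2 for g in grids))
    shell_index = np.round(radius).astype(int).ravel()
    shell_sums = np.bincount(shell_index, weights=volume_3d.ravel())
    shell_counts = np.bincount(shell_index)
    # Guard against empty shells so the division never hits zero.
    return shell_sums / np.maximum(shell_counts, 1)
```

For an amplitude spectrum this would be called on the shifted magnitude, e.g. `radial_average_3d(np.abs(np.fft.fftshift(np.fft.fftn(tomogram))))`.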

Does the spectrum matching take a while because of the large fft it computes or is it no big deal? If so we might consider calculating the sum of spectra over a number of smaller 3D patches to do the estimation - this can also be used to increase signal by taking overlapping patches

Thanks a lot for your suggestions. I'll definitely try to implement them in the next iteration. The project.scripts block and the Typer function for the command line interface sound like they could make the whole package much more convenient to use! :)

Regarding the timing of the FFT: It does indeed take a while to compute the FFTs for the large tomograms. So you would propose a sliding window approach for extraction / matching of the frequencies? Since FFT scales with O(n*log(n)), it should be more efficient to compute a lot of smaller FFTs than computing the FFT of the entire volume, right? But is the transform then still roughly equivalent? I guess I'll do some experiments on this! :)

alisterburt commented 1 year ago

> Regarding the timing of the FFT: It does indeed take a while to compute the FFTs for the large tomograms. So you would propose a sliding window approach for extraction / matching of the frequencies? Since FFT scales with O(n*log(n)), it should be more efficient to compute a lot of smaller FFTs than computing the FFT of the entire volume, right? But is the transform then still roughly equivalent?

I wouldn't match on a window directly, I would average the FFTs over the sliding windows then do my spectrum matching on that average spectrum - dealing with the smaller spectrum should be a little easier and if needed the FFTs of the windows could be evaluated in parallel.
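
Roughly, the averaging could be sketched like this with NumPy. The function name, `patch_size` and `stride` are hypothetical, and the sketch averages amplitude (magnitude) spectra, which is an assumption about what "average the FFTs" should mean here, since averaging the complex values directly would let phases cancel.

```python
import numpy as np


def average_patch_amplitude_spectrum(volume, patch_size=64, stride=32):
    """Mean amplitude spectrum over (possibly overlapping) cubic patches."""
    # Start offsets of the sliding window along each axis.
    starts = [range(0, n - patch_size + 1, stride) for n in volume.shape]
    total = np.zeros((patch_size,) * 3)
    count = 0
    for z in starts[0]:
        for y in starts[1]:
            for x in starts[2]:
                patch = volume[z:z + patch_size, y:y + patch_size, x:x + patch_size]
                total += np.abs(np.fft.fftn(patch))
                count += 1
    if count == 0:
        raise ValueError("volume is smaller than one patch")
    # Shift so the zero frequency sits at the centre of the returned array.
    return np.fft.fftshift(total / count)
```

A stride smaller than the patch size gives overlapping windows. The patches are independent of each other, so the loop is the part that could be batched or run in parallel.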

A quick look at complexity suggests batching should be quicker (n*log(n/8) rather than n*log(n) when splitting into 8 blocks), but actual tests don't seem to show a huge benefit there...

import torch

# IPython session: one 256^3 FFT vs. a batch of eight 128^3 FFTs (same total voxel count)
a = torch.rand((256, 256, 256))
b = torch.rand((8, 128, 128, 128))
%timeit torch.fft.fftn(a, dim=(-3, -2, -1))
%timeit torch.fft.fftn(b, dim=(-3, -2, -1))
148 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
117 ms ± 583 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
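
As a sanity check on the numbers, the idealised n*log(n) model itself only predicts a modest gain for this split. The snippet below is back-of-the-envelope arithmetic that ignores constant factors, memory traffic and library-specific optimisations, so it is an estimate and not a performance model of `torch.fft`.

```python
import math


def fft_cost(n_voxels):
    """Idealised FFT operation count, n * log2(n), ignoring constant factors."""
    return n_voxels * math.log2(n_voxels)


full = fft_cost(256**3)            # one 256^3 transform
batched = 8 * fft_cost(128**3)     # eight 128^3 transforms, same voxel count
predicted_ratio = batched / full   # log2 terms are 21 vs 24
measured_ratio = 117 / 148         # timings quoted above
print(f"predicted {predicted_ratio:.3f}, measured {measured_ratio:.3f}")
# prints "predicted 0.875, measured 0.791"
```

So under this model the batched version should take about 87.5% of the time of the full transform, which is in the same ballpark as the measured difference.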