teamtomo / membrain-seg

membrane segmentation in 3D for cryo-ET

On the fly rescaling (GPU) #64

Closed LorenzLamm closed 1 month ago

LorenzLamm commented 3 months ago

Regarding the discussion in https://github.com/teamtomo/membrain-seg/issues/55, I added the option to perform rescaling of the inference patches on the fly and on GPU.

With this, the user can simply input tomograms at any pixel size; the model will perform sliding-window inference and rescale each patch individually to e.g. 10Å. The output is again a segmentation at the original pixel size.

This is really fast (the overhead depends on the tomogram pixel size, but it is small, e.g. 150s vs. 155s inference time) and did not noticeably change segmentation quality.

How it's done internally:

  1. The sliding-window inferer's window size is adjusted such that, after rescaling the window (e.g. to 10Å), it has the target size (default 160).
  2. Rescaling is performed within the model itself (as preprocessing / postprocessing steps). This way, we do not need to touch the SWInferer class, which seemed convenient to me. Workflow: rescale patch to 160^3 --> model prediction --> rescale back to original shape (see the sketch after this list).
  3. The SWInferer stitches the patches back together in the original dimensions.
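
A minimal sketch of this wrapper idea (an illustration, not the PR's actual code; `fourier_rescale_3d` is a hypothetical stand-in for the PR's Fourier cropping / extension utilities, using trilinear interpolation only for brevity):

```python
# Minimal sketch of the wrapper idea described above -- an illustration,
# not the PR's actual code. `fourier_rescale_3d` stands in for the PR's
# Fourier cropping / extension utilities.
import torch
import torch.nn as nn
import torch.nn.functional as F


def adjusted_window_size(pixel_size: float, target_pixel_size: float = 10.0,
                         target_size: int = 160) -> int:
    # Step 1: pick the window size so that rescaling it to the target
    # pixel size yields exactly `target_size` voxels per axis.
    return round(target_size * target_pixel_size / pixel_size)


def fourier_rescale_3d(x: torch.Tensor, shape: tuple) -> torch.Tensor:
    # Stand-in for Fourier cropping / extension of a (B, C, D, H, W) patch.
    return F.interpolate(x, size=shape, mode="trilinear", align_corners=False)


class RescalingWrapper(nn.Module):
    # Step 2: rescaling happens inside the model's forward pass, so the
    # SWInferer never needs to know about pixel sizes.
    def __init__(self, model: nn.Module, target_size: int = 160):
        super().__init__()
        self.model = model
        self.target_size = target_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = tuple(x.shape[2:])                     # patch (D, H, W)
        x = fourier_rescale_3d(x, (self.target_size,) * 3)  # preprocessing
        y = self.model(x)                                   # prediction
        return fourier_rescale_3d(y, orig_shape)            # postprocessing
```

The SWInferer (step 3) can then be run on the wrapped model with the adjusted window size; it stitches the per-patch predictions together at the original pixel size without knowing about the rescaling.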

@alisterburt This is not using any libtilt functionality yet (I only found Fourier cropping / padding for 2D, but I guess this could easily be extended). I guess the rescaling itself could be done in a more sophisticated way, but I'm not sure if that's necessary for this task? @rdrighetto

Happy for any feedback :)

uermel commented 1 month ago

Hi @LorenzLamm, great PR, this does improve performance for segmentation in some initial tests I did.

I have one request: it would be nice to set the torch device for the rescaling functions to the model's device. At the moment, fourier_cropping_torch and fourier_extend_torch set the device to an unspecified CUDA device, which causes exceptions when trying to run inference on a specific GPU. I've modified this in our inference wrapper, but it would be good to have this upstream.

It's a simple change (this is in our wrapper):
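
The snippet itself is not reproduced here; below is a minimal sketch of the kind of change meant, deriving the device from the input tensor instead of hardcoding CUDA (the Fourier-cropping body is an illustrative stand-in, not membrain-seg's actual implementation):

```python
from typing import Optional

import torch


def fourier_cropping_torch(data: torch.Tensor, new_shape: tuple,
                           device: Optional[torch.device] = None
                           ) -> torch.Tensor:
    # The fix: default to the device the data already lives on (i.e. the
    # model's device) instead of an unspecified "cuda" device, so inference
    # pinned to e.g. cuda:1 stays on cuda:1.
    device = device if device is not None else data.device
    data = data.to(device)
    # Illustrative 3D Fourier cropping: keep the centered low frequencies.
    spectrum = torch.fft.fftshift(torch.fft.fftn(data))
    starts = [(s - n) // 2 for s, n in zip(data.shape, new_shape)]
    cropped = spectrum[starts[0]:starts[0] + new_shape[0],
                       starts[1]:starts[1] + new_shape[1],
                       starts[2]:starts[2] + new_shape[2]]
    # Invert and rescale amplitudes for the changed voxel count.
    return torch.real(torch.fft.ifftn(torch.fft.ifftshift(cropped))) \
        * (cropped.numel() / data.numel())
```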

codecov-commenter commented 1 month ago

Codecov Report

Attention: Patch coverage is 0% with 92 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (bea5ae0) to head (ef62e08). Report is 6 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| ...mbrain_seg/segmentation/networks/inference_unet.py | 0.00% | 41 Missing :warning: |
| ..._preprocessing/matching_utils/px_matching_utils.py | 0.00% | 27 Missing :warning: |
| src/membrain_seg/segmentation/segment.py | 0.00% | 24 Missing :warning: |


Additional details and impacted files

```diff
@@           Coverage Diff           @@
##            main      #64   +/-   ##
=======================================
  Coverage   0.00%    0.00%
=======================================
  Files         40       46     +6
  Lines       1411     1631   +220
=======================================
- Misses     1411     1631   +220
```


LorenzLamm commented 1 month ago

Hey @uermel ,

Thanks a lot for your feedback on this. Sorry for the late reply -- vacation kept me from working on this :) I have incorporated your suggestions and the device is now the same as the model's device.

This functionality seems to work -- I think it's ready to merge into main.