lxasqjc / Deformation-Segmentation

PyTorch implementation of Learning to Downsample for Segmentation of Ultra-High Resolution Images [ICLR 2022]

Question about training with loss at high resolution #13

Closed mtchiu2 closed 1 year ago

mtchiu2 commented 1 year ago

Hi,

Thank you very much for your work.

I'm trying to train a model on a custom dataset, and I want to keep the segmentation map at its original resolution during loss calculation (rather than downsampling the GT as done in the paper). However, I encountered an error in models.py, lines 570-578. I wonder if there is a bug in the code, where the image (x) is upsampled rather than the prediction (pred)?
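
For clarity, here is a minimal sketch of what computing the loss at the original resolution would presumably look like: the low-resolution prediction is upsampled to the GT size before the loss, instead of the GT being downsampled. The names `pred`, `gt`, and the use of `F.interpolate` are illustrative assumptions, not the exact code in models.py:

```python
import torch.nn.functional as F

def high_res_loss(pred, gt):
    # pred: (N, C, h, w) low-resolution logits; gt: (N, H, W) full-resolution labels.
    # Upsample the prediction (not the input image x) to the GT resolution,
    # then compute the loss against the untouched high-resolution GT.
    pred_up = F.interpolate(pred, size=gt.shape[-2:], mode='bilinear', align_corners=False)
    return F.cross_entropy(pred_up, gt)
```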

It also seems that enabling loss_at_high_res pulls in additional imports (e.g. Interp2D) that do not appear to be intended for use, and there are no instructions in this repository on how to get them working.

As a side question, do you observe any performance difference when computing the loss at the original resolution as opposed to downsampling the GT?

lxasqjc commented 1 year ago

Hi @mtchiu2, thanks for pointing it out! Yes, you are right: the input between L. 570-578 should be pred, not x; I have corrected it in af9252d. However, I didn't test it, so it may still lead to your second issue of the missing Interp2D function.

This idea was implemented in a very early stage of the project but was later dropped, hence it is not maintained. We did test it, and yes, if you calculate the loss at the original high resolution you will get a performance improvement; that could be seen as the potential performance upper bound for a given segmentation backbone.

However, the downsides are: 1) the computational cost of recovering the prediction to high resolution to compute the loss is very high; 2) it contradicts our original motivation: cutting down the computational cost of segmenting ultra-high-resolution images while minimizing the performance drop (i.e. achieving the best cost-performance trade-off).

mtchiu2 commented 1 year ago

Thank you very much for the answer.

Is it still possible to get that option to work? I tried enabling it but encountered even more errors. It seems to require installing the spatial/ package in the repo, which involves some .so libraries that I couldn't get working without instructions.

I would assume the motivation is efficiency at inference time? In cases where the GT is sparse/small, computing the loss at low resolution (especially with very high downsampling rates) seems like it would miss a lot of foreground.
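
As a rough, self-contained illustration of that concern (not code from this repo): a tiny foreground object in a sparse GT can disappear entirely when the label map is downsampled with nearest-neighbour interpolation at a high rate:

```python
import torch
import torch.nn.functional as F

gt = torch.zeros(1, 1, 1024, 1024)   # full-resolution label map
gt[..., 500:504, 500:504] = 1        # a single 4x4 foreground object

# 16x downsampling of the labels, as would happen when the GT is shrunk for the loss
gt_small = F.interpolate(gt, scale_factor=1 / 16, mode='nearest')

print(gt.sum().item(), gt_small.sum().item())  # 16.0 vs (often) 0.0: the object is gone
```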

lxasqjc commented 1 year ago

Hi @mtchiu2, sorry for the very late reply; I currently cannot work on this myself due to other research duties. However, the .so library issues are likely caused by an incompatible environment, which you may want to check. Sorry I can't be more helpful.