ivadomed / Counterfactual_synthesis_for_lesion_segmentation


Push input size to 32 x 256 x 256 for diffusion #3

Closed BreziTasbi closed 2 months ago

BreziTasbi commented 4 months ago

The input size for diffusion training and inference is currently constrained by the available graphics memory. The existing implementation already includes an option to reduce the parameter bit depth from 32 to 16, which significantly reduces VRAM usage. This optimization makes it straightforward to handle 16 x 256 x 256 images on a 48 GB GPU, but nothing larger, because the decoding step is highly memory-demanding. At 1 mm resolution, the entire spinal cord cannot fit into the frame.

To address this, I propose an unconventional yet effective method to reduce the memory cost and increase the input size: split the diffusion output in latent space, decode it part by part, and then reassemble the decoded parts to obtain the final image.

(The most recent commit introduces a "decoding_diviser" variable in "config/model/vq_gan_3d.yaml" which controls the number of parts the decoding process will be divided into.)
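To make the idea concrete, here is a minimal sketch of part-by-part decoding, not the repository's actual code. It assumes the latent is a 3D array split along its first (depth) axis; `decode_in_parts`, `toy_decode` and `n_parts` are hypothetical names, and `toy_decode` is only a stand-in for the real VQGAN decoder.

```python
import numpy as np

def decode_in_parts(latent, decode, n_parts):
    """Decode `latent` chunk by chunk along the depth axis (axis 0) and
    concatenate the decoded chunks in their original order."""
    chunks = np.array_split(latent, n_parts, axis=0)
    # Only one chunk goes through the decoder at a time, which bounds peak memory.
    decoded = [decode(chunk) for chunk in chunks]
    return np.concatenate(decoded, axis=0)

def toy_decode(z):
    """Stand-in decoder (assumption): nearest-neighbour upsampling by 2 on every axis."""
    return z.repeat(2, axis=0).repeat(2, axis=1).repeat(2, axis=2)

latent = np.random.default_rng(0).normal(size=(8, 16, 16))
full = toy_decode(latent)                                # single-pass decoding
parts = decode_in_parts(latent, toy_decode, n_parts=2)   # two-part decoding
```

With this toy decoder both paths give identical volumes, because each output voxel depends on a single latent voxel. A convolutional decoder uses neighbouring context that is missing at the cut, so its output can differ near the boundaries between parts.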

However, I anticipate that this method may introduce some spatial inconsistencies. Therefore, I plan to assess the extent of these potential side effects as soon as possible.

BreziTasbi commented 4 months ago

The first result is dramatically spatially inconsistent, so much so that I suspect there is a mistake in the way the latent space is split.

(Attached images: 37, 39)

BreziTasbi commented 4 months ago

I modified the decoding step to add a 50% overlap between consecutive parts and took advantage of it to add a fade between them. The result no longer shows spatial inconsistencies. Still, I don't know whether this kind of DIY is legitimate:

(Attached images: 98, 97)
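For reference, the overlap-and-fade idea can be sketched as follows. This is an illustration under assumptions, not the committed implementation: windows are taken along the depth axis with a stride of half the window, the fade is a linear (triangular) weight, and `decode_with_overlap`, `toy_decode` and `window` are hypothetical names.

```python
import numpy as np

def decode_with_overlap(latent, decode, window):
    """Decode half-overlapping depth windows and cross-fade them.

    latent: array of shape (depth, height, width).
    decode: callable mapping a latent window to a decoded 3D array.
    window: even window size, in latent slices.
    """
    depth = latent.shape[0]
    stride = window // 2
    if window % 2 or depth < window or (depth - window) % stride:
        raise ValueError("depth must be covered exactly by half-overlapping windows")

    out = weight_sum = ramp = None
    for start in range(0, depth - window + 1, stride):
        part = decode(latent[start:start + window])
        n = part.shape[0]          # decoded window depth
        scale = n // window        # decoder upsampling factor along depth
        if out is None:
            out = np.zeros((depth * scale,) + part.shape[1:])
            weight_sum = np.zeros(depth * scale)
            # Triangular weights: low at the window edges, high at its centre.
            ramp = np.minimum(np.arange(1, n + 1), np.arange(n, 0, -1)).astype(float)
        lo = start * scale
        out[lo:lo + n] += part * ramp[:, None, None]
        weight_sum[lo:lo + n] += ramp
    # Normalise so the weights of overlapping windows sum to one at every slice.
    return out / weight_sum[:, None, None]

def toy_decode(z):
    """Stand-in decoder (assumption): nearest-neighbour upsampling by 2 on every axis."""
    return z.repeat(2, axis=0).repeat(2, axis=1).repeat(2, axis=2)

latent = np.random.default_rng(0).normal(size=(8, 16, 16))
blended = decode_with_overlap(latent, toy_decode, window=4)  # 3 windows: 0-4, 2-6, 4-8
```

Every slice in an overlap becomes a weighted average of two decoded windows, so a mismatch at a cut is spread over the whole overlap instead of appearing as a sharp seam. The averaging itself can also blur details, which is one possible cost of the fade.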

Now that the decoding step no longer bottlenecks the input size, I wonder how far we can get with 48 GB. Let's test it out: it was quick, 32 x 256 x 256 is the new limit ^^

BreziTasbi commented 3 months ago

I had let the good-looking images convince me, but I had not yet assessed the quality drawbacks.

Test performed on 788 unseen images (T2W canproco data) cropped to 32 x 128 x 256. I measured the L2 loss between the input and the output of the VQGAN using no split (normal usage), a split into 3 parts of shape 16 x 128 x 256 with smoothing (second method), and a split into 2 parts of shape 16 x 128 x 256 without smoothing:

(Attached figure: outputL2)
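As an illustration of the measurement, here is a minimal sketch of such a comparison. It assumes the L2 loss is the mean squared error over voxels, averaged over the test volumes; the exact loss used for the figure is not stated above. `l2_loss`, `mean_l2` and the toy `reconstruct` callables are hypothetical stand-ins for the VQGAN encode/decode round trip.

```python
import numpy as np

def l2_loss(original, reconstruction):
    """Mean squared error over all voxels (assumed definition of the L2 loss)."""
    original = np.asarray(original, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    return float(np.mean((original - reconstruction) ** 2))

def mean_l2(volumes, reconstruct):
    """Average L2 loss of one reconstruction method over a set of volumes."""
    return float(np.mean([l2_loss(v, reconstruct(v)) for v in volumes]))

# Toy data and toy "reconstructions" (assumptions, small shapes to keep it fast).
rng = np.random.default_rng(0)
volumes = [rng.random((32, 16, 16)) for _ in range(3)]
identity_score = mean_l2(volumes, lambda v: v)        # perfect reconstruction
offset_score = mean_l2(volumes, lambda v: v + 0.1)    # constant error of 0.1
```

Calling `mean_l2` once per decoding variant (no split, split with smoothing, split without smoothing) on the same volumes gives directly comparable scores.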

I'm quite surprised to see that smoothing seems to worsen the output quality, since the spatial inconsistency is quite dramatic without smoothing. I would be curious to see whether another metric, more dedicated to assessing spatial consistency, would reveal the superiority of smoothing. Nonetheless, the quality degradation stays in the same range with and without splitting.
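One possible candidate for such a metric, given purely as a hypothetical sketch and not something that was tested here, is to compare the slice-to-slice intensity jump at the cut with the typical jump elsewhere in the volume. The name `seam_jump_ratio` and its definition are assumptions for illustration.

```python
import numpy as np

def seam_jump_ratio(volume, seam_indices):
    """Ratio of the mean absolute slice-to-slice difference at the seams
    to the same quantity over all other adjacent slice pairs.

    volume: array of shape (depth, height, width).
    seam_indices: depth indices i where a new part starts, i.e. the seam
    lies between slice i - 1 and slice i.
    """
    # diffs[k] is the mean absolute difference between slice k and slice k + 1.
    diffs = np.abs(np.diff(volume, axis=0)).mean(axis=(1, 2))
    seam_pairs = [i - 1 for i in seam_indices]
    at_seams = diffs[seam_pairs].mean()
    elsewhere = np.delete(diffs, seam_pairs).mean()
    return float(at_seams / elsewhere)

# Toy volumes (assumption): a smooth ramp along depth, and the same ramp
# with an artificial intensity jump where the second part would start.
smooth = np.arange(32, dtype=float)[:, None, None] * np.ones((32, 4, 4))
broken = smooth.copy()
broken[16:] += 5.0

smooth_ratio = seam_jump_ratio(smooth, seam_indices=[16])
broken_ratio = seam_jump_ratio(broken, seam_indices=[16])
```

A ratio close to 1 means the seam is no more abrupt than any other pair of neighbouring slices, while a larger ratio points to a visible discontinuity at the cut. Unlike a global L2 loss, this only looks at the boundary, so it might separate the smoothed and unsmoothed variants more clearly.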