fepegar / torchio

Medical imaging toolkit for deep learning
https://torchio.org
Apache License 2.0

interpolated resampling for label #342

Closed romainVala closed 3 years ago

romainVala commented 3 years ago

🚀 Feature

Let the user choose the label interpolation method, and propose a more accurate method than the current default (nearest neighbor). I do not remember in which article I found the idea (one of Billot's, I think), but it is simple. To resample a 3D label map:

  1. transform it to a 4D one hot encode
  2. resample each 3D volume, with any interpolation you want (default trilinear)
  3. take the argmax across the label to reconstruct a resampled 3D label map
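The three steps above can be sketched with plain NumPy/SciPy, independently of torchio (`resample_label_one_hot` and the zoom factor are illustrative names, not an existing API):

```python
import numpy as np
from scipy.ndimage import zoom

def resample_label_one_hot(label, factor, num_classes=None):
    """Resample a 3D label map via one-hot encoding (hypothetical helper)."""
    if num_classes is None:
        num_classes = int(label.max()) + 1
    # 1. one-hot encode: one float channel per label, shape (C, D, H, W)
    one_hot = np.stack([(label == c).astype(float) for c in range(num_classes)])
    # 2. resample each channel with trilinear interpolation (order=1)
    resampled = np.stack([zoom(channel, factor, order=1) for channel in one_hot])
    # 3. argmax across channels to recover a discrete label map
    return resampled.argmax(axis=0)

label = np.zeros((8, 8, 8), dtype=int)
label[2:6, 2:6, 2:6] = 1
upsampled = resample_label_one_hot(label, 2)
print(upsampled.shape)  # (16, 16, 16)
```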

Motivation

Reduce the label perturbation induced by spatial transformations. I would like to quantify the difference by comparing different interpolation strategies.

Pitch

I am not sure if this is the only choice, but I think one should add an interpolation_methode parameter for labels only (previously this was implicitly nearest neighbor). There would then be two interpolation keywords, but I think it makes sense.

Alternatives

Do it manually by importing the one-hot encoded version, but that is too much extra work.

Additional context

I think it is important for the inverse-transform problem: https://github.com/fepegar/torchio/issues/299

fepegar commented 3 years ago

Hi, @romainVala

I don't think it's more efficient, but I've also heard of this technique (from Eric Kerfoot).

I guess the affected transforms would be Resample, RandomAffine and RandomElasticDeformation, right?

You're welcome to submit a PR!

Here's some code I just used for testing, in case it helps:

import torch
import torchio as tio

t = tio.RandomAffine(seed=423507)

sub = tio.datasets.ICBM2009CNonlinearSymmetric()

# Build a label map with an explicit background channel
tissues = sub.tissues.data
background = 1 - tissues.sum(dim=0)
tissues = torch.cat((background[None], tissues))
label = tio.LabelMap(tensor=tissues.argmax(dim=0, keepdim=True), affine=sub.tissues.affine)

# One-hot version of the same labels, stored as a ScalarImage so it gets
# linearly interpolated instead of nearest-neighbor
one_hot = torch.nn.functional.one_hot(label.data[0].long()).permute(3, 0, 1, 2)
one_hot_img = tio.ScalarImage(tensor=one_hot, affine=label.affine)

new = tio.Subject(t1=sub.t1, tissues=label, one_hot=one_hot_img)

transformed = t(new)

# Recover a label map from the transformed one-hot volume
from_one_hot = transformed.one_hot.data.argmax(dim=0, keepdim=True)
from_one_hot_img = tio.LabelMap(tensor=from_one_hot, affine=transformed.one_hot.affine)

final = tio.Subject(
    t1=transformed.t1,
    tissues=transformed.tissues,
    tissues_one_hot=from_one_hot_img,
)

final.plot()
romainVala commented 3 years ago

If it does not make a big difference, maybe it is not worth it... (since it may be very time consuming with many labels).

Thanks for the code, I will start testing with that. Do you know if it has been quantified? What I would like is to compare the Dice loss after applying two affines that come back to the initial location, so I can compare with the original labels.

But then I am missing the exact inverse affine (as mentioned in issue #299); maybe I'll wait for that to be solved...

fepegar commented 3 years ago

If it does not make a big difference, maybe it is not worth it... (since it may be very time consuming with many labels).

Exactly. I started testing with FPG's segmentation, but it has too many labels.

Do you know if it has been quantified?

No idea. We can ask Benjamin Billot, since you say they've used this.

What I would like is to compare the Dice loss after applying two affines that come back to the initial location, so I can compare with the original labels.

Yeah, makes sense.

But then I am missing the exact inverse affine (as mentioned in issue #299); maybe I'll wait for that to be solved...

I think you can still try this with SimpleITK or SPM or whatever. Remember SPM? 😬

romainVala commented 3 years ago

SPM, yes I do remember it, and unfortunately still use it... But to do a quick test, I ended up hacking RandomAffine, adding these 3 lines in apply_affine_transform:

        # resample back with the inverse of the applied affine
        transform_inv = transform.GetInverse()
        resampler.SetTransform(transform_inv)
        resampled = resampler.Execute(resampled)

So you are right, it does not make a big difference, but still a small one. I computed the Dice loss, (1 - Dice) * 100, between the reference and the twice-transformed label or one-hot (so a Dice loss of 0.5 corresponds to a Dice of 99.5). I tested 100 random affines:

dice loss ref / label:     0.51 +- 0.24   [min, max] = [0.065, 1.2]
dice loss ref / one hot:   0.32 +- 0.053  [min, max] = [0.23, 0.46]
dice loss label / one hot: 0.73 +- 0.24   [min, max] = [0.31, 1.4]

So nearest neighbor gives more variance, which makes sense, and interpolation with one-hot is more stable and leads to less difference, as expected. But this is very small, so I guess it can be neglected.
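For reference, the quoted metric, (1 - Dice) * 100 averaged over labels, can be sketched as follows (a hypothetical reimplementation; the exact averaging used in the experiment above is not shown in the thread):

```python
import numpy as np

def dice_loss_percent(a, b):
    """(1 - Dice) * 100 between two integer label maps, averaged over labels."""
    labels = np.union1d(np.unique(a), np.unique(b))
    scores = []
    for lab in labels:
        x, y = a == lab, b == lab
        # Dice = 2 |X ∩ Y| / (|X| + |Y|) for this label
        scores.append(2 * np.logical_and(x, y).sum() / (x.sum() + y.sum()))
    return (1 - np.mean(scores)) * 100

a = np.zeros((10, 10, 10), dtype=int)
a[2:8, 2:8, 2:8] = 1
b = a.copy()
b[2, 2, 2] = 0  # perturb a single voxel
print(dice_loss_percent(a, b))  # prints a small loss, well below 1
```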

Not sure it is worth a PR, given the extra computation time needed...

fepegar commented 3 years ago

I think you can fairly say that there's pretty much no difference if the Dice score is between 99.3 and 100!

romainVala commented 3 years ago

You are biased by working too much on real data! (with bad Dice scores)

I would be careful before concluding there is no difference, mainly because the Dice score is a global metric.

But wait, this was on your example, which includes 3 tissues. I then took other data with subcortical structures (which are smaller), and the story looks different:

dice loss ref / label:   5.5 +- 2.8     [min, max] = [0.5, 1.2e+01]
dice loss ref / one hot: 0.16 +- 0.049  [min, max] = [0.084, 0.25]

fepegar commented 3 years ago

Hahah shame on me for working with real data!

Ok, it might be interesting to have an implementation, then. If your team wants to submit a PR, I'll help.

romainVala commented 3 years ago

Ok, I was too quick to publish the results; those are wrong. I get errors due to the default pad value, which adds values in the wrong class. With a random affine you lose the border, which then influences the computed Dice score.

I tried different versions of the default padding, but could not find a proper one. Taking the minimum is ok for the label map (if the background has label value 0), but for one-hot it works for all the tissues except the background (which should take the max instead of the mean).

I guess I will have to go with pad and crop to get a good estimate...
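A minimal NumPy sketch of the padding problem described above: zero-padding a one-hot volume leaves padded voxels with no winning class, whereas padding the background channel with 1 keeps them unambiguously background (the shapes and per-channel pad values are my choice for illustration):

```python
import numpy as np

# 3 classes, background is channel 0; everything background for simplicity
one_hot = np.zeros((3, 4, 4, 4))
one_hot[0] = 1

pad = [(0, 0)] + [(1, 1)] * 3

# naive padding: every channel padded with 0
zero_padded = np.pad(one_hot, pad, constant_values=0)

# per-channel padding: 1 for the background channel, 0 for the tissues
bg_padded = np.stack([
    np.pad(one_hot[c], pad[1:], constant_values=1 if c == 0 else 0)
    for c in range(one_hot.shape[0])
])

# With zero padding, padded voxels have all-zero scores, so argmax assigns
# label 0 only because ties resolve to the first channel; with background
# padding they are explicitly background.
print(zero_padded[:, 0, 0, 0])  # [0. 0. 0.]
print(bg_padded[:, 0, 0, 0])    # [1. 0. 0.]
```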

romainVala commented 3 years ago

I tried to quickly implement the same strategy for the elastic deformation, but when I try

resampler.GetTransform().GetInverse() or directly bspline_transform.GetInverse(), I get the error sitk::ERROR: Unable to create inverse!

Any idea?

romainVala commented 3 years ago

Ok, it is not implemented in SimpleITK; here is a suggestion: https://discourse.itk.org/t/inverse-of-bspline-transform/496/2. But since I am not very familiar with SimpleITK, help would be appreciated...

fepegar commented 3 years ago

About inverting the transform (should this be in this issue?).

It seems doable. Hopefully it'll be easier soon:

import torchio as tio
import SimpleITK as sitk

colin = tio.datasets.Colin27()
transform = tio.RandomElasticDeformation()
transformed = transform(colin)

# Negate the control-point displacements recorded in the history to
# approximate the inverse deformation
grid = -transformed.history[0][1]['coarse_grid']
sitk_image = transformed.t1.as_sitk()
sitk_transform = transform.get_bspline_transform(sitk_image, transform.num_control_points, grid)
resampled_back = sitk.Resample(sitk_image, sitk_transform)
sitk.WriteImage(resampled_back, '/tmp/back.nii')

The idea is inverting the red arrows shown in this gist: https://gist.github.com/fepegar/b723d15de620cd2a3a4dbd71e491b59d

romainVala commented 3 years ago

About inverting the transform (should this be in this issue?).

Sorry, I answered there: #299

fepegar commented 3 years ago

According to the comments, I think this can be closed after #353. @romainVala, feel free to reopen if needed.