balbasty / nitorch

Neuroimaging in PyTorch

Incompatibility with newer torch versions, e.g. in current MONAI docker images #77

Closed: nvahmadi closed this issue 1 year ago

nvahmadi commented 1 year ago

Hi, as discussed with @brudfors, I was just trying to run the demo_affine_reg.ipynb notebook in a current MONAI container (e.g. projectmonai/monailabel:0.6.0). Unfortunately, the torch function for solving A\b has changed: Tensor.solve has been removed in favor of torch.linalg.solve. Even the first NITorch code cell:

pth_mris_ra = realign(pth_mris, prefix='ra_')

yields the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [9], line 1
----> 1 pth_mris_ra = realign(pth_mris, prefix='ra_')

Cell In [4], line 24, in realign(pths, prefix, odir, t_std, r_std)
     22 M = dat.affine
     23 R = _expm(q[n, ...], basis=B)
---> 24 M = M.solve(R)[0]
     25 # Modify affine in header
     26 dat.set_metadata(affine=M)

File /opt/conda/lib/python3.8/site-packages/torch/_tensor.py:639, in Tensor.solve(self, other)
    636 def solve(self, other):
    637     from ._linalg_utils import solve
--> 639     return solve(self, other)

File /opt/conda/lib/python3.8/site-packages/torch/_linalg_utils.py:102, in solve(input, A, out)
    101 def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
--> 102     raise RuntimeError(
    103         "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead.",
    104     )

RuntimeError: This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead.
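For readers hitting the same error, here is a minimal sketch of the one-line migration. It assumes the removed call followed the old torch.solve(B, A) convention (solve A @ X = B and return (X, LU)). Under that convention, M.solve(R)[0] solves R @ X = M, so the replacement is torch.linalg.solve(R, M). The new function takes the matrix first and returns only the solution. The matrices below are made-up stand-ins for the notebook's dat.affine and _expm(...) output.

```python
import torch

# Hypothetical stand-ins for the notebook's variables:
# M plays the role of dat.affine (a 4x4 voxel-to-world matrix),
# R plays the role of _expm(...) (a 4x4 rigid transform: rotation about z plus a shift).
M = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0], dtype=torch.float64))
theta = torch.tensor(0.1, dtype=torch.float64)
R = torch.eye(4, dtype=torch.float64)
R[0, 0], R[0, 1] = torch.cos(theta), -torch.sin(theta)
R[1, 0], R[1, 1] = torch.sin(theta), torch.cos(theta)
R[0, 3] = 5.0

# Old (removed) call:  M = M.solve(R)[0]   -> solved R @ X = M, returned (X, LU)
# New equivalent:      torch.linalg.solve(A, B) solves A @ X = B and returns X only
M_new = torch.linalg.solve(R, M)
```

Using a solver instead of computing torch.linalg.inv(R) @ M explicitly is generally the more numerically stable choice, although for a well-conditioned 4x4 rigid transform both give the same result in practice.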

The same error is raised by all subsequent NITorch code cells, e.g. by the command:

dat_aligned = affine_align(pth_mris_ra, device=device, cost_fun=cost_fun, samp=samp)[0]

It would be great if NITorch could use the updated syntax; otherwise I'd need to search for the last torch/MONAI version that still works. Thanks in advance! :)

brudfors commented 1 year ago

Thanks @nvahmadi!

@balbasty, shall we replace every use of .solve() with the NITorch alternative, lmdiv?

balbasty commented 1 year ago

That would probably be the best alternative. lmdiv uses solve (or linalg.solve) under the hood, but it checks the torch version so that the correct function is used. I like that most of the code base works across many torch versions, so I'd rather go that way.
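The version check described above could be sketched roughly as follows. This is a hypothetical helper (lmdiv_sketch), not NITorch's actual lmdiv, whose signature and options may differ. It assumes that newer torch releases provide torch.linalg.solve (it appeared around torch 1.8). It also assumes that older releases only provide torch.solve(B, A), which returns a (solution, LU) tuple.

```python
import torch


def lmdiv_sketch(a, b):
    r"""Left matrix division a\b, i.e. solve a @ x = b for x.

    Hypothetical helper illustrating per-version dispatch; not NITorch's lmdiv.
    """
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
        # Newer torch: matrix first, returns the solution tensor directly.
        return torch.linalg.solve(a, b)
    # Older torch: right-hand side first, returns (solution, LU).
    return torch.solve(b, a)[0]


# Usage: 3x + y = 9 and x + 2y = 8, whose solution is x = 2, y = 3.
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]], dtype=torch.float64)
B = torch.tensor([[9.0], [8.0]], dtype=torch.float64)
x = lmdiv_sketch(A, B)
```

Feature-detecting with hasattr rather than parsing torch.__version__ keeps a helper like this working on pre-release and custom builds, whose version strings can be irregular.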

brudfors commented 1 year ago

Okay, I will do that then.

brudfors commented 1 year ago

@nvahmadi can you please test the solve2lmdiv branch and see if it resolves your issues? Thanks!

nvahmadi commented 1 year ago

Sure! Does it need a rebuild of the C/CUDA extensions?

brudfors commented 1 year ago

That should not be needed.

nvahmadi commented 1 year ago

I tried it - it works! :) Excellent, many thanks!