kornia / kornia

Geometric Computer Vision Library for Spatial AI
https://kornia.readthedocs.io
Apache License 2.0

K.Resize() doesn't work on MPS devices #3026

Status: Open. calebrob6 opened this issue 3 weeks ago.

calebrob6 commented 3 weeks ago

Describe the bug

`K.Resize()` doesn't work when the device is `torch.device("mps")`. It fails with the following error:

```
 File "/Users/.../miniforge3/envs/ftw4/lib/python3.12/site-packages/kornia/utils/helpers.py", line 232, in _torch_solve_cast
    out = torch.linalg.solve(A.to(torch.float64), B.to(torch.float64))
                             ^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```

Also mentioned in #1717

Reproduction steps

1. Create `resizer = K.Resize(...).to("mps")`
2. Apply `resizer` to an input tensor on the same device.
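To separate the failure from kornia, the cast shown in the traceback can be reproduced with PyTorch alone. This is a minimal sketch, not a full reproduction: `reproduce_float64_cast` is a hypothetical helper that only mimics the `A.to(torch.float64)` call from `_torch_solve_cast`, and it does not run `K.Resize` itself.

```python
import torch


def reproduce_float64_cast(device: torch.device) -> str:
    """Hypothetical helper: repeat the float64 cast that fails in this issue's traceback."""
    A = torch.eye(3, device=device)  # small float32 matrix on the target device
    try:
        A.to(torch.float64)  # same cast that _torch_solve_cast performs before solving
    except TypeError as exc:  # on MPS this is the TypeError reported in the issue
        return f"failed: {exc}"
    return "ok"


if __name__ == "__main__":
    # Use MPS when present; on other machines the CPU run should report "ok".
    dev = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(dev, reproduce_float64_cast(dev))
```

On an Apple Silicon machine with MPS available, the expectation is a `failed: ...` message matching the traceback; elsewhere the cast succeeds.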

Expected behavior

`K.Resize()` should resize the input, falling back to float32 where float64 is not supported.
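One possible direction for a fix, sketched under assumptions: pick the compute dtype from the device instead of always upcasting to float64. `pick_solve_dtype` and `solve_with_fallback` are hypothetical names, not kornia's API, and the sketch assumes `torch.linalg.solve` itself runs on MPS in float32, which may depend on the installed PyTorch version.

```python
import torch


def pick_solve_dtype(device: torch.device) -> torch.dtype:
    """Return float64 where supported, float32 on MPS (which has no float64)."""
    return torch.float32 if device.type == "mps" else torch.float64


def solve_with_fallback(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical cast-then-solve helper that avoids float64 on MPS."""
    compute_dtype = pick_solve_dtype(A.device)
    out = torch.linalg.solve(A.to(compute_dtype), B.to(compute_dtype))
    # Cast the result back so callers get the dtype they passed in.
    return out.to(A.dtype)


if __name__ == "__main__":
    A = torch.tensor([[2.0, 0.0], [0.0, 4.0]])
    B = torch.tensor([[2.0], [8.0]])
    print(solve_with_fallback(A, B))  # solves A @ X = B on CPU in float64
```

The trade-off is precision: on MPS the solve would run in float32, so results may differ slightly from the float64 path used on other devices.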

Environment

Unsure; I can find more details if necessary.
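If more environment details are needed, PyTorch ships a full report via `python -m torch.utils.collect_env`. A smaller sketch that gathers the fields most relevant to this bug might look like the following; `collect_env_summary` is a hypothetical helper and the field names are illustrative choices.

```python
import platform
from importlib import metadata

import torch


def collect_env_summary() -> dict:
    """Hypothetical helper: gather version and MPS details relevant to this issue."""
    try:
        kornia_version = metadata.version("kornia")
    except metadata.PackageNotFoundError:
        kornia_version = "not installed"
    return {
        "python": platform.python_version(),
        "platform": platform.platform(),  # OS version and CPU architecture
        "torch": torch.__version__,
        "kornia": kornia_version,
        "mps_built": torch.backends.mps.is_built(),  # PyTorch compiled with MPS support
        "mps_available": torch.backends.mps.is_available(),  # MPS usable at runtime
    }


if __name__ == "__main__":
    for key, value in collect_env_summary().items():
        print(f"{key}: {value}")
```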

Additional context

No response

edgarriba commented 3 weeks ago

@calebrob6, would you be open to submitting a PR with a fix? I don't have a device with an MPS backend.