pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io
Other
370 stars 109 forks

Don't force `.cpu()` on all PyTorch outputs #1052

Open ricardoV94 opened 1 month ago

ricardoV94 commented 1 month ago

This whole thing (i.e., calling out.cpu()) is suboptimal. I think we don't need it for JAX (which returns JAX arrays, not NumPy arrays) because np.asarray works on them, and I guess it doesn't work for torch tensors.

https://github.com/pymc-devs/pytensor/blob/7b13a955daba591b5af5c6d09e9ef4095b465890/pytensor/link/pytorch/linker.py#L16

This should only be needed for updated shared variables, where we have to convert to a common type because they could be used in multiple functions with distinct backends.

Perhaps we should expand a bit on the TorchLinker to perform the updates itself, and only force conversion when that's the case. This is already supported by Function.

https://github.com/pymc-devs/pytensor/blob/7b13a955daba591b5af5c6d09e9ef4095b465890/pytensor/compile/function/types.py#L1009-L1017

_Originally posted by @ricardoV94 in https://github.com/pymc-devs/pytensor/pull/1032#discussion_r1821221676_
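The idea of converting only the values that update shared variables, while leaving the other outputs in their backend-native type, can be sketched as follows. This is a minimal, hypothetical helper (`convert_update_outputs` is not a PyTensor API). It assumes that the update values are the last `n_updates` entries of the output list, and that `convert` is whatever turns a backend array into a NumPy array. Both are assumptions of the sketch.

```python
import numpy as np


def convert_update_outputs(outputs, n_updates, convert=np.asarray):
    """Hypothetical: convert only the trailing `n_updates` outputs.

    Assumes update values for shared variables are appended after the
    user-requested outputs.
    """
    if not 0 <= n_updates <= len(outputs):
        raise ValueError("n_updates must be between 0 and len(outputs)")
    split = len(outputs) - n_updates
    # User-facing outputs keep their native type (e.g. torch tensors on GPU).
    user_outputs = list(outputs[:split])
    # Values written back into shared variables get a backend-neutral type,
    # because those variables may be read by functions compiled for other backends.
    update_values = [convert(o) for o in outputs[split:]]
    return user_outputs + update_values
```

In a torch linker, `convert` would be a function that moves a tensor to host memory and turns it into a NumPy array, so the transfer cost is paid only for updates.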

Ch0ronomato commented 1 month ago

I would really like to work on this if possible; it has burned me a few times.

ricardoV94 commented 1 month ago

Of course :)

Ch0ronomato commented 3 weeks ago

@ricardoV94, I'm wondering whether this overlaps with the issue I found when experimenting with PyMC + py[torch|tensor]: #1065. What I'm wondering is whether the linker should be smart enough to know when to do

result.detach().numpy()

Then the issue with PyMC should, in theory, be solved.

ricardoV94 commented 3 weeks ago

The problem is that this fails when the data is on the GPU. Is there a cheap way to know when it is and when it's not? Just wrap it in a try/except?

Ch0ronomato commented 3 weeks ago

Yeah, x.device gives you the current location of the tensor. As long as we check for CPU, it's fairly straightforward (GPU device names vary).
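A minimal sketch of this device check, assuming each output is either a torch tensor or something np.asarray already handles. `to_numpy` is a hypothetical name, and duck typing on `.detach()` stands in for the `isinstance(out, torch.Tensor)` check a real linker would more likely use, so the sketch runs without PyTorch installed.

```python
import numpy as np


def to_numpy(out):
    """Hypothetical: turn a linker output into a NumPy array, copying only if needed."""
    # Anything without .detach() (NumPy arrays, Python scalars, ...) is
    # assumed to be something np.asarray already understands.
    if not hasattr(out, "detach"):
        return np.asarray(out)
    # torch.device.type is "cpu" for host memory; accelerator names vary
    # ("cuda", "mps", ...), so only the CPU case is tested explicitly.
    on_cpu = out.device.type == "cpu"
    # Drop autograd history: .numpy() refuses tensors that require grad.
    out = out.detach()
    if not on_cpu:
        # Only pay for a device-to-host copy when the data is not on the CPU.
        out = out.cpu()
    # For CPU tensors, .numpy() shares memory with the tensor instead of copying.
    return out.numpy()
```

Compared with a try/except around `.numpy()`, reading `out.device.type` is a cheap attribute lookup and does not rely on catching the error torch raises for GPU tensors. Recent PyTorch releases also provide `Tensor.numpy(force=True)`, which performs the detach and CPU transfer in a single call.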

ricardoV94 commented 3 weeks ago

Want to try that? It's still suboptimal to always force a transfer, but probably fine for rough use of the backend. We may allow user control via custom linker settings in the future.

Ch0ronomato commented 3 weeks ago

Would we combine this with the suggestion you made earlier as well?

Perhaps we should expand a bit on the TorchLinker to perform the updates itself, and only force conversion when that's the case. This is already supported by Function.

ricardoV94 commented 3 weeks ago

Let's skip the updates idea for now and force everything to be NumPy once it's out. Otherwise you'll run into the same sort of problems you saw in your PyMC tests.