ricardoV94 opened this issue 1 month ago
I would really like to work on this if possible, it burned me a few times
Of course :)
@ricardoV94 I'm wondering if this overlaps with the issue I found when messing around with pymc + py[torch|tensor]: #1065. Should the linker be smart enough to know when to do `result.detach().numpy()`? Then the issue with pymc should, in theory, be solved.
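Roughly something like this, just to sketch the idea (not the actual linker code):

```python
import torch


def to_numpy(result: torch.Tensor):
    # Naive version of the proposed conversion: detach from the autograd
    # graph and hand back a numpy array. Only works for CPU tensors.
    return result.detach().numpy()
```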
The problem is that this fails when the data is on the GPU. Is there a cheap way to know when it is and when it's not? Just wrap it in a try/except?
Yea, `x.device` gives you the current location of the tensor. As long as we check for cpu it's fairly straightforward (GPU device names vary).
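Something along these lines, as a rough sketch (the helper name is made up):

```python
import numpy as np
import torch


def torch_to_numpy(x: torch.Tensor) -> np.ndarray:
    # Hypothetical helper: only pay for a device transfer when the tensor
    # is not already on the CPU. Checking for "cpu" avoids having to
    # enumerate the various accelerator device names (cuda:0, mps, xpu, ...).
    if x.device.type != "cpu":
        x = x.cpu()
    return x.detach().numpy()
```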
Wanna try that? It's still suboptimal to always force the transfer, but probably fine for a rough use of the backend. We may allow user control with custom linker settings in the future.
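Purely hypothetical sketch of what such a setting could look like (the flag and the `convert_output` hook are made up; I'm assuming the class is exported as `PytorchLinker` from the module linked below):

```python
from pytensor.link.pytorch.linker import PytorchLinker


class ConfigurablePytorchLinker(PytorchLinker):
    """Hypothetical: a user-controlled switch for the forced tensor -> numpy copy."""

    def __init__(self, *args, force_numpy_outputs: bool = True, **kwargs):
        self.force_numpy_outputs = force_numpy_outputs
        super().__init__(*args, **kwargs)

    def convert_output(self, out):
        # Made-up hook: only force the CPU/numpy copy when asked to;
        # otherwise hand back the torch tensor (possibly on an accelerator).
        if not self.force_numpy_outputs:
            return out
        return out.detach().cpu().numpy()
```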
Would we combine this with the suggestion you had earlier as well?

> Perhaps we should expand a bit on the TorchLinker to perform the updates itself, and only force conversion when that's the case. This is already supported by `Function`.
Let's skip the updates idea for now and force everything to be numpy once it's out. Otherwise you'll have the same sort of problems you saw in your PyMC tests.
This whole thing (i.e., calling `out.cpu()`) is suboptimal. I think we don't need it for JAX (which returns JAX arrays, not numpy arrays), because `np.asarray` works with those, and I guess it doesn't work for torch tensors. https://github.com/pymc-devs/pytensor/blob/7b13a955daba591b5af5c6d09e9ef4095b465890/pytensor/link/pytorch/linker.py#L16
This should only be needed for updated shared variables, where we have to convert to a common type because they could be used in multiple functions with distinct backends.
Perhaps we should expand a bit on the TorchLinker to perform the updates itself, and only force conversion when that's the case. This is already supported by `Function`. https://github.com/pymc-devs/pytensor/blob/7b13a955daba591b5af5c6d09e9ef4095b465890/pytensor/compile/function/types.py#L1009-L1017
_Originally posted by @ricardoV94 in https://github.com/pymc-devs/pytensor/pull/1032#discussion_r1821221676_
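To make the shared-variable point concrete, a minimal illustrative sketch (assuming the `PYTORCH` mode registered by the backend; names and shapes are arbitrary):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

# One shared variable updated by a PyTorch-compiled function but also read
# by a function compiled with the default (C/numpy) backend.
state = pytensor.shared(np.zeros(3), name="state")
y = pt.dvector("y")

f_torch = pytensor.function([y], state + y, updates={state: state + y}, mode="PYTORCH")
f_default = pytensor.function([y], state * 2)

f_torch(np.ones(3))
# If the update were stored back as a torch.Tensor, state.get_value() and
# f_default would be handed a non-numpy object; converting updated shared
# variables back to numpy keeps both backends interoperable.
f_default(np.ones(3))
```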