Closed: adam-hartshorne closed this issue 1 year ago
I believe execution is guaranteed here. The torch function runs on the numpy data that JAX exports for pure_callback, so it should behave as expected as long as JAX exports that data correctly, which as far as I can tell is extensively tested.
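For concreteness, here is a minimal sketch of the pattern being described, not the actual code in this repository: jax.pure_callback hands the wrapped function host-side numpy arrays, the PyTorch function runs eagerly on them, and the numpy result flows back into the JAX computation. The names torch_fn and wrapped are purely illustrative.

import jax
import jax.numpy as jnp
import numpy as np
import torch


def torch_fn(x: np.ndarray) -> np.ndarray:
    # PyTorch runs eagerly on the host-side numpy data that JAX passes to the callback.
    return torch.sin(torch.as_tensor(x)).numpy()


def wrapped(x):
    # Declare the output shape/dtype up front; the callback itself is opaque to JAX tracing.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(torch_fn, out_spec, x)


print(jax.jit(wrapped)(jnp.arange(4.0)))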
Ok cool.
I wonder if it is possible to use the reverse of jax2torch, using dlpack, to ensure that data isn't moved off the GPU to the CPU?
Definitely, take a look at this possibility:
from __future__ import annotations

import numpy as np
import torch
import torch.utils.dlpack
from torch import Tensor

import jax.dlpack
from jax import Array


def transfer(x: Array | Tensor, via: str = "dlpack", device: str = "cuda"):
    assert via in ("dlpack", "cpu")
    if isinstance(x, Array):
        # JAX -> PyTorch
        if via == "dlpack":
            # Zero-copy handoff through the DLPack protocol; the data stays on its current device.
            return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
        else:
            # Round trip through host memory via numpy, then place on the requested device.
            return torch.as_tensor(np.array(x), device=device)
    else:
        # PyTorch -> JAX
        if via == "dlpack":
            return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        else:
            return jax.device_put(jax.numpy.array(x.detach().cpu().numpy()), device=jax.devices(device)[0])
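As a quick usage sketch of the transfer helper above: the default via="dlpack" path shares the underlying buffer and avoids a host round trip, while via="cpu" copies through numpy.

x = jax.numpy.arange(4.0)
t = transfer(x)   # JAX Array -> torch.Tensor, shared buffer via DLPack
y = transfer(t)   # torch.Tensor -> JAX Array, shared buffer via DLPack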
I notice that you use jax.pure_callback to call the PyTorch function that is being wrapped.
My understanding from https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html is that pure_callback isn't guaranteed to execute in order. Maybe I am misunderstanding what they mean by "guaranteed execution", but is there not a danger that the wrapped PyTorch function doesn't execute as expected?