dlpack transfer is possible when not under jit, which is the case for both jax2torch and silk (I believe).
The purpose of this package is to allow conversion under jit, which is a harder problem, but I'll look into it, thanks!
I'm going to leave this issue open until I can figure out how to use dlpack with jit.
I have it working under JIT using the code in the facebook repo. One additional change I made to improve performance further was to replace
y_, ctx.fun_vjp = jax.vjp(fn, *args)
with
vjp = jit(lambda *args: jax.vjp(fn, *args))
y_, ctx.fun_vjp = vjp(*args)
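For context, here is a rough sketch of where that replacement sits in a jax2torch/silk-style autograd wrapper. This is not the exact code from the facebook repo: the wrap_jax_fn name and the t2j/j2t dlpack converters are hypothetical stand-ins, fn is assumed to be a pure JAX function with a single array output, and gradient handling is kept minimal.

import jax
import torch

def wrap_jax_fn(fn, t2j, j2t):
    # Build the vjp under jit once, as in the snippet above.
    jitted_vjp = jax.jit(lambda *args: jax.vjp(fn, *args))

    class JaxFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *torch_args):
            jax_args = tuple(t2j(a) for a in torch_args)
            y_, ctx.fun_vjp = jitted_vjp(*jax_args)
            return j2t(y_)

        @staticmethod
        def backward(ctx, grad_output):
            # fun_vjp maps the output cotangent back to input cotangents.
            grads = ctx.fun_vjp(t2j(grad_output))
            return tuple(j2t(g) for g in grads)

    return JaxFn.apply

As far as I can tell, jitting the lambda works because jax.vjp returns its backward pass as a jax.tree_util.Partial, which JAX treats as a pytree and can therefore return from a jitted function.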
I can't get the transfer to work under jit in this example:
from __future__ import annotations

import numpy as np
import jax.dlpack
import torch
from torch import Tensor
import torch.utils.dlpack
from jax import Array


# Move an array between JAX and PyTorch, either zero-copy via dlpack or via a CPU round trip.
def transfer(x: Array | Tensor, via: str = "dlpack", device: str = "cuda"):
    assert via in ("dlpack", "cpu")
    if isinstance(x, Array):
        if via == "dlpack":
            return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
        else:
            return torch.as_tensor(np.array(x), device=device)
    else:
        if via == "dlpack":
            return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        else:
            return jax.device_put(jax.numpy.array(x.detach().cpu().numpy()), device=jax.devices(device)[0])


# A JAX-facing function that hands its input to torch.sum via dlpack and converts the result back.
def f(x: Array) -> Array:
    device = "cuda"
    return transfer(torch.sum(transfer(x, via="dlpack", device=device)), via="dlpack", device=device)


r = jax.device_put(jax.numpy.ones(1000).astype(jax.numpy.float32), jax.devices("cuda")[0])

try:
    print(f(r))
    print("Untraced JAX function did work.")
except:
    print("Untraced JAX function did NOT work.")

try:
    print(jax.jit(f)(r))
    print("Traced JAX function did work.")
except:
    print("Traced JAX function did NOT work.")
which prints
1000.0
Untraced JAX function did work.
Traced JAX function did NOT work.
Let me know if you have an idea how to somehow call pure_callback (to be able to call pytorch code under JAX jit) with an alternative data serialization (not via CPU transfer, but via dlpack instead).
You are trying to jit transfer(torch.sum(transfer(x, ...))), and that definitely can't work as is: it includes a direct call to a torch operator, and under jit the arguments are abstract tracers rather than concrete arrays, so there is nothing for dlpack to export and nothing for torch.sum to operate on.
When I said I could get a jitted function to work, I meant that I have pure JAX function(s) decorated with jit, then wrapped with jax2torch (as per the links), and I call that wrapped function from within PyTorch code. This works fine and passes gradients through correctly for running optimisation. I have used this on a fairly complex model and there doesn't seem to be any obvious performance hit.
For your example above, at the very least I think you would need a wrapper which uses jax.pure_callback to "wrap" the PyTorch function called from within JAX (which is then jittable); then perhaps you could wrap this new "JAX" function with the jax2torch functionality (a rough sketch of the pure_callback part follows below). I sort of hope this might be possible, as long as all the gradients are handled properly, but I have no idea what the performance overhead might be from all this mixing.
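To make the pure_callback idea concrete, here is a rough sketch of wrapping a PyTorch call so it can be invoked from inside jax.jit. The torch_sum_host callback and jit_torch_sum function are hypothetical examples (not code from any of the linked repos), gradients are ignored, and note that pure_callback hands the callback host NumPy arrays, which is exactly the CPU round trip discussed just below.

import jax
import jax.numpy as jnp
import torch

def torch_sum_host(x_np):
    # pure_callback delivers x_np as a NumPy array on the host,
    # so the torch work goes via CPU memory, not dlpack.
    return torch.sum(torch.as_tensor(x_np)).numpy()

@jax.jit
def jit_torch_sum(x):
    out_spec = jax.ShapeDtypeStruct((), x.dtype)  # scalar output, same dtype as x
    return jax.pure_callback(torch_sum_host, out_spec, x)

print(jit_torch_sum(jnp.ones(1000, dtype=jnp.float32)))  # 1000.0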
Edit - I have just stumbled across this, which might make some things neater / easier:
filter_pure_callback - Calls a Python function inside a JIT region. As jax.pure_callback, but accepts arbitrary Python objects as inputs and outputs (not just JAX-able types). It returns the result of callback(*args, **kwargs), valid for use under JIT.
I think that's right. I'm trying to dig into how JAX implements pure_callback. The problem with using pure_callback is that the inputs and outputs will be sent to and from the CPU respectively, so any performance benefit from using dlpack will be lost.
I think filter_pure_callback just filters out non-array-like arguments, converting them to static arguments, so it won't help here, as we're fundamentally trying to support numerical data.
My apologies, I don't know how I missed that pure_callback executes on CPU, like any Python+NumPy function. So yes, obviously for calling PyTorch from inside JAX that isn't going to help.
Making your own Custom Operator (https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html) looks like the only way to call functions such that the data stays on the GPU, and the documentation for that is terrible / it looks like a nightmare (unlike PyTorch / TensorFlow, where the interface is fairly straightforward).
I have some work in progress on no-copy calling of PyTorch from JAX, including under JIT; you can check it out here:
https://github.com/rdyro/torch2jax
@adam-hartshorne
I notice that in pytorch.py the functionality to convert between PyTorch and JAX arrays goes via a conversion to the CPU. It is possible to do this without moving the data at all.
More can be found here: https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py and https://github.com/facebookresearch/silk/blob/main/lib/utils/jax.py