rdyro / jaxfi-JAXFriendlyInterface

Friendly Interface to JAX, `jaxfi` simplifies JAX interface to replicate PyTorch's ease of use.
MIT License

More efficient passing of arrays between jax / pytorch #1

Closed adam-hartshorne closed 1 year ago

adam-hartshorne commented 1 year ago

I notice in pytorch.py the functionality to convert between pytorch and jax arrays goes via a conversion to the cpu. It is possible to do this without moving the data at all.

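from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

def j2t(x_jax, device=None):
    # JAX -> PyTorch via DLPack, no host round trip
    x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
    if device:
        x_torch = x_torch.to(device)
    return x_torch

def t2j(x_torch, device=None):
    # PyTorch -> JAX via DLPack; the tensor must be contiguous first
    x_torch = x_torch.contiguous()  # https://github.com/google/jax/issues/8082
    if device:
        x_torch = x_torch.to(device)
    x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
    return x_jax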

More can be found here:
https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py
https://github.com/facebookresearch/silk/blob/main/lib/utils/jax.py

rdyro commented 1 year ago

dlpack transfer is possible when not under jit, which is the case for both jax2torch and silk (I believe).

The purpose of this package is to allow conversion under jit, which is a harder problem, but I'll look into it, thanks!

rdyro commented 1 year ago

I'm going to leave this issue open until I can figure out how to use dlpack with jit.

adam-hartshorne commented 1 year ago

I have it working under JIT using the code in the facebook repo. One additional change I made to improve performance further was to replace

y_, ctx.fun_vjp = jax.vjp(fn, *args)

with

vjp = jit(lambda *args: jax.vjp(fn, *args))
y_, ctx.fun_vjp = vjp(*args)
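
For readers following along, here is a minimal standalone sketch of that change outside of any PyTorch wrapper. The function `fn` and the shapes are hypothetical stand-ins, not taken from the linked repos. It relies on the pullback returned by `jax.vjp` being a pytree, which is what allows it to be returned from a jitted function.

```python
import jax
import jax.numpy as jnp

def fn(x, w):
    # Hypothetical stand-in for the wrapped JAX function.
    return jnp.tanh(x @ w).sum()

# Compile the forward pass together with the construction of the pullback.
vjp = jax.jit(lambda *args: jax.vjp(fn, *args))

x = jnp.ones((4, 3))
w = jnp.full((3, 2), 0.5)

y, fun_vjp = vjp(x, w)
# The pullback takes a cotangent shaped like the output
# and returns one gradient per input.
grad_x, grad_w = fun_vjp(jnp.ones_like(y))
```

The jitted wrapper is cached per input shape and dtype, so repeated calls with same-shaped arguments should skip retracing the forward pass.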

rdyro commented 1 year ago

I can't get the transfer to work under jit in this example:

from __future__ import annotations

import numpy as np
import jax.dlpack
import torch
from torch import Tensor
import torch.utils.dlpack
from jax import Array

def transfer(x: Array | Tensor, via: str = "dlpack", device: str = "cuda"):
    assert via in ("dlpack", "cpu")
    if isinstance(x, Array):
        if via == "dlpack":
            return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
        else:
            return torch.as_tensor(np.array(x), device=device)
    else:
        if via == "dlpack":
            return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        else:
            return jax.device_put(jax.numpy.array(x.detach().cpu().numpy()), device=jax.devices(device)[0])

def f(x: Array) -> Array:
    device = "cuda"
    return transfer(torch.sum(transfer(x, via="dlpack", device=device)), via="dlpack", device=device)

r = jax.device_put(jax.numpy.ones(1000).astype(jax.numpy.float32), jax.devices("cuda")[0])

try:
    print(f(r))
    print("Untraced JAX function did work.")
except Exception:
    print("Untraced JAX function did NOT work.")

try:
    print(jax.jit(f)(r))
    print("Traced JAX function did work.")
except Exception:
    print("Traced JAX function did NOT work.")

which prints

1000.0
Untraced JAX function did work.
Traced JAX function did NOT work.

Let me know if you have an idea how to somehow call pure_callback (to be able to call pytorch code under JAX jit) with an alternative data serialization (not via CPU transfer, but dlpack instead).

adam-hartshorne commented 1 year ago

You are trying to jit

transfer(torch.sum(transfer(x,

that definitely can't work as is, as it includes a pure call to a torch operator.
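
A small JAX-only sketch of why this happens, with `probe` as a hypothetical helper: under jit the function is called with an abstract tracer rather than a concrete array. The tracer still passes an `isinstance(x, jax.Array)` check, so a DLPack export gets attempted on something that has no device buffer to hand over.

```python
import jax
import jax.numpy as jnp

seen = []

def probe(x):
    # Record what kind of object the function actually receives.
    seen.append((isinstance(x, jax.Array), isinstance(x, jax.core.Tracer)))
    return x.sum()

r = jnp.ones(1000, dtype=jnp.float32)
probe(r)           # eager call: x is a concrete array
jax.jit(probe)(r)  # traced call: x is a tracer standing in for the array
```

After both calls, `seen` holds one entry per call, and only the second entry has the tracer flag set.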

When I said I can get a jitted function to work, I meant that I have pure JAX function(s) decorated with jit, then wrapped with jax2torch (as per the links), and I call that wrapped function from within PyTorch code. This works fine and can pass gradients through in order to run optimisation. I have used this on a fairly complex model and there doesn't seem to be any obvious performance hit.

For your example above, at the very least I think you would need a wrapper that uses jax.pure_callback to "wrap" the PyTorch function called from within JAX (which is then jittable); then perhaps you could wrap this new "JAX" function with the jax2torch functionality. I sort of hope this might be possible, as long as all the gradients are properly handled, but I have no idea what the performance overhead might be from all this mixing.
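
A rough sketch of the first step, assuming a NumPy function as a stand-in for the PyTorch call so that it runs without torch installed. The helper `host_sum` is hypothetical. This covers the forward pass only; gradients would need extra handling, for example via `jax.custom_vjp`.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sum(x):
    # Runs as ordinary Python on the host; a PyTorch call could go here instead.
    return np.asarray(np.sum(np.asarray(x)), dtype=np.float32)

@jax.jit
def f(x):
    # The callback's output shape and dtype must be declared up front.
    out_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(host_sum, out_spec, x)

result = f(jnp.ones(1000, dtype=jnp.float32))
```

Here `result` is a scalar JAX array produced inside a jitted function, even though the sum itself was computed by non-JAX code.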

adam-hartshorne commented 1 year ago

Edit: I have just stumbled across this in Equinox, which might make some things neater / easier:

filter_pure_callback - Calls a Python function inside a JIT region. As jax.pure_callback but accepts arbitrary Python objects as inputs and outputs. (Not just JAXable types.) The result of callback(*args, **kwargs), valid for use under JIT.

rdyro commented 1 year ago

I think that's right. I'm trying to dig into how JAX implements pure_callback. The problem with using pure_callback is that both the inputs and outputs will be sent to and from the CPU respectively, so any performance from using dlpack will be lost.

I think filter_pure_callback just filters out non-array-like arguments to convert them to static arguments, so it won't help here as we're fundamentally trying to support numerical data.

adam-hartshorne commented 1 year ago

My apologies, I don't know how I missed that pure_callback executes on CPU, like any Python+NumPy function. So yes, obviously for calling PyTorch from inside JAX that isn't going to help.

Making your own Custom Operator (https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html) looks like the only way to call functions such that data stays on the GPU, and the documentation for that is terrible / looks a nightmare (unlike PyTorch / Tensorflow, where the interface is fairly straightforward).

rdyro commented 1 year ago

I have some work in progress on no-copy calling of PyTorch from JAX, including under JIT. You can check it out here:

https://github.com/rdyro/torch2jax

@adam-hartshorne