rdyro / jaxfi-JAXFriendlyInterface

Friendly Interface to JAX: `jaxfi` simplifies the JAX interface to replicate PyTorch's ease of use.
MIT License

Question about wrap_torch_fn function #2

Closed: adam-hartshorne closed this issue 1 year ago

adam-hartshorne commented 1 year ago

I notice that you use jax.pure_callback to call the PyTorch function that is being wrapped.

My understanding from https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html is that pure_callback isn't guaranteed to execute in order. Maybe I am misunderstanding what they mean by "guaranteed execution", but is there not a danger that the wrapped PyTorch function doesn't execute as expected?

rdyro commented 1 year ago

I believe the execution is guaranteed. The torch function currently executes on NumPy data exported from JAX, so it should work as expected as long as JAX correctly exports data for pure_callback, which is extensively tested as far as I can tell.
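
For reference, the mechanism is roughly like this. This is a simplified sketch rather than the exact jaxfi implementation; the `wrap_torch_fn` and `result_shape_dtypes` names here are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
import torch

def wrap_torch_fn(torch_fn, result_shape_dtypes):
    # Sketch: wrap a PyTorch function so it can be called from jitted JAX code.
    def numpy_callback(*np_args):
        # pure_callback hands the callback NumPy arrays exported from JAX
        torch_args = [torch.as_tensor(np.asarray(a)) for a in np_args]
        out = torch_fn(*torch_args)
        # return NumPy data matching the declared result shapes/dtypes
        return np.asarray(out.detach().cpu().numpy())

    def wrapped(*args):
        return jax.pure_callback(numpy_callback, result_shape_dtypes, *args)

    return wrapped

# usage: the output shapes/dtypes of the torch function must be declared up front
f = wrap_torch_fn(torch.sin, jax.ShapeDtypeStruct((3,), jnp.float32))
print(jax.jit(f)(jnp.arange(3, dtype=jnp.float32)))
```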

adam-hartshorne commented 1 year ago

Ok cool.

I wonder if it is possible to use the reverse of jax2torch, via dlpack, to ensure that data isn't moved off the GPU to the CPU?

rdyro commented 1 year ago

Definitely, take a look at this possibility:

from __future__ import annotations

import numpy as np
import jax.dlpack
import torch
from torch import Tensor
import torch.utils.dlpack
from jax import Array

def transfer(x: Array | Tensor, via: str = "dlpack", device: str = "cuda"):
    """Move an array between JAX and PyTorch, either zero-copy via DLPack or through the CPU."""
    assert via in ("dlpack", "cpu")
    if isinstance(x, Array):
        # JAX -> torch
        if via == "dlpack":
            # zero-copy handoff, the tensor stays on the same device
            return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
        else:
            # round-trip through host memory, then place on the requested torch device
            return torch.as_tensor(np.array(x), device=device)
    else:
        # torch -> JAX
        if via == "dlpack":
            # zero-copy handoff, the array stays on the same device
            return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        else:
            # round-trip through host memory, then place on the requested JAX device
            return jax.device_put(jax.numpy.array(x.detach().cpu().numpy()), device=jax.devices(device)[0])
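
A quick usage sketch (variable names are illustrative; with the default `via="dlpack"` the data stays on whatever device it already lives on, so nothing is copied to the host):

```python
x_jax = jax.numpy.arange(3.0)
x_torch = transfer(x_jax)    # JAX -> torch via DLPack, no host round-trip
y_torch = torch.sin(x_torch)
y_jax = transfer(y_torch)    # torch -> JAX via DLPack
```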