openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

PJRT C API does not support GetDefaultLayout, however it is called (always) by DLPackManagedTensorToBuffer #10613

Open mmalex opened 7 months ago

mmalex commented 7 months ago

Hello! I'm trying to write a simple CUDA/Python module (coincidentally using https://github.com/wjakob/nanobind to provide the DLPack integration), and I have a method that returns a CUDA-allocated array to JAX via DLPack. However, when JAX tries to construct an array from the DLPack object returned by nanobind, we hit the exception "UNIMPLEMENTED: PJRT C API does not support GetDefaultLayout". I traced the flow of code as follows. jax.dlpack calls

    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, cpu_backend, gpu_backend))

in its dlpack.py, which seems innocuous enough. dlpack_managed_tensor_to_buffer is a wrapper around DLPackManagedTensorToBuffer, which is defined in the xla/python/dlpack.cc file of the openxla/xla project. Unfortunately, as a sanity check, that function validates whether the strides you pass in are equal to the default ones. How does it work out the default strides? It always calls GetDefaultLayout, here: https://github.com/openxla/xla/blob/ac8fb1fda904b0283e5926b9758506bba8ce9e0a/xla/python/dlpack.cc#L418-L423, which leads to this line: https://github.com/openxla/xla/blob/886a1917270d3edbb3d02f9fe0954736dfede8db/xla/pjrt/pjrt_c_api_client.h#L286-L290, marked 'TODO: implement'. Considering that this function is on the path jax.dlpack uses to construct arrays, I'm surprised it's not implemented, and that I haven't found an existing issue for this.
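For readers unfamiliar with the check described above, here is a simplified Python sketch of what "strides equal to the default ones" means for a dense array. It assumes the default layout is compact row-major (which is also what DLPack assumes when a tensor's strides pointer is NULL) and that strides are counted in elements, as in DLPack's DLTensor. The function names are hypothetical; this is not XLA's code, which derives the default from GetDefaultLayout and may treat edge cases such as size-1 dimensions differently.

```python
from typing import Optional, Sequence, Tuple


def default_row_major_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Strides (in elements, as DLPack counts them) of a compact row-major array."""
    strides = []
    running = 1
    # Walk the dimensions from minor (last) to major (first); each stride is
    # the product of the sizes of all dimensions to its right.
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def strides_match_default(shape: Sequence[int],
                          strides: Optional[Sequence[int]]) -> bool:
    """Hypothetical stand-in for the sanity check described above.

    DLPack allows strides to be NULL (None here), meaning the tensor is
    compact and row-major, so that case matches by definition.
    """
    if strides is None:
        return True
    return tuple(strides) == default_row_major_strides(shape)
```

For example, a (2, 3) array with strides (3, 1) passes, while the column-major strides (1, 2) would not. Note that in the reported bug the failure happens earlier: the default layout itself cannot be queried through the PJRT C API, so the comparison never runs.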

I'm afraid I'm fairly new to JAX, DLPack, and nanobind, so forgive me if I am making some dumb mistake. How would you suggest I proceed? Regardless of this particular code path, the problem I'm trying to solve is returning a CUDA array to JAX via DLPack, ideally using nanobind as a wrapper over DLPack. If there's another way to do that, that would work too :)

The TODO is attributed to @skye, so forgive me for @-mentioning you.

skye commented 7 months ago

How are you installing jax? If you use the standard cuda install (https://github.com/google/jax?tab=readme-ov-file#instructions), that does not use the PJRT C API and should not have this problem.

I'm also working on fixing this in the PJRT C API so we can switch to a cuda plugin that does use the C API, but you shouldn't need to wait for that.

mmalex commented 7 months ago

Thanks for the reply! I had no idea that was even a thing :) I installed JAX on Python 3.10 via pip:

pip install jax[cuda12]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

This was 'inherited' without thought from a colleague, Zac. I'll ask him more about it.


skye commented 7 months ago

Ah, OK: [cuda12] is the plugin install (we gave it the simple name because we want it to be the one true CUDA install once we fix these C API issues). Try:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

(cuda12_pip instead of just cuda12)

skye commented 7 months ago

Closing this issue. Please reopen if you're still having issues!

EDIT: jk, I don't have permissions to close openxla issues :) Please let me know if there's still an issue, otherwise I think we can go ahead and close it!