Open mmalex opened 7 months ago
How are you installing jax? If you use the standard cuda install (https://github.com/google/jax?tab=readme-ov-file#instructions), that does not use the PJRT C API and should not have this problem.
I'm also working on fixing this in the PJRT C API so we can switch to a cuda plugin that does use the C API, but you shouldn't need to wait for that.
thanks for the reply! i had no idea that was even a thing :) i installed jax on py3.10 via a pip install of jax[cuda12]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. this was 'inherited' without thought from a colleague, zac. i'll ask him more about it.
Ah ok, [cuda12] is the plugin install (we kept it simple because we want this to be the one true cuda install once we fix these C API issues). Try:
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
(cuda12_pip instead of just cuda12)
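To check which flavour of install is currently in an environment, something like the following sketch may help. It only looks at installed package metadata and does not import jax. The distribution names jax-cuda12-plugin and jax-cuda12-pjrt are an assumption about how the plugin install is packaged, and the helper names are made up for illustration, so adjust them if your environment differs.

```python
from importlib import metadata

# Assumed distribution names for the plugin-based [cuda12] install.
# These are a guess for illustration; edit them if your install differs.
PLUGIN_DISTRIBUTIONS = ("jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def describe_jax_install():
    """Summarise which jax-related distributions are present."""
    names = ("jax", "jaxlib") + PLUGIN_DISTRIBUTIONS
    report = {name: installed_version(name) for name in names}
    # If any plugin distribution is present, the C API plugin path is likely in use.
    report["uses_plugin"] = any(
        report[name] is not None for name in PLUGIN_DISTRIBUTIONS
    )
    return report


if __name__ == "__main__":
    for key, value in describe_jax_install().items():
        print(f"{key}: {value}")
```

If uses_plugin comes out True, the environment probably still has the plugin packages from the [cuda12] install, and uninstalling them before switching may be worth trying.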
Closing this issue. Please reopen if you're still having issues!
EDIT: jk, I don't have permissions to close openxla issues :) Please let me know if there's still an issue, otherwise I think we can go ahead and close it!
hello! I'm trying to write a simple CUDA/Python module (coincidentally using https://github.com/wjakob/nanobind to provide the dlpack integration) and I have a method that returns a cuda-allocated array to jax, via dlpack. however, when jax tries to construct an ndarray from the dlpack object returned by nanobind, we hit an exception:
UNIMPLEMENTED: PJRT C API does not support GetDefaultLayout
I traced the flow of code as follows: jax.dlpack calls dlpack_managed_tensor_to_buffer in its dlpack.py, which seems innocuous enough. dlpack_managed_tensor_to_buffer is a wrapper around DLPackManagedTensorToBuffer, which is defined in the xla/xla/python/dlpack.cc file of the openxla project. sadly, that function works by validating that the strides you passed in are equal to the default ones, just as a sanity check. how does it work out the default strides? why, it calls (always!) GetDefaultLayout, here https://github.com/openxla/xla/blob/ac8fb1fda904b0283e5926b9758506bba8ce9e0a/xla/python/dlpack.cc#L418-L423 which leads to this line https://github.com/openxla/xla/blob/886a1917270d3edbb3d02f9fe0954736dfede8db/xla/pjrt/pjrt_c_api_client.h#L286-L290 'TODO: implement'. considering that this function is called by the constructor of ndarrays in jax.dlpack, i'm surprised it's not implemented, and/or i haven't found an existing issue for this.

i'm afraid i'm fairly new to jax, dlpack, and nanobind, so forgive me if I am making some dumb mistake. how would you suggest I proceed? regardless of this particular code path, the problem I am trying to solve is to return a cuda array to jax via dlpack, ideally using nanobind as a wrapper over dlpack. if there's another way to do that, that would work too :)
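For anyone following along, here is a rough Python sketch of what a "strides equal the default ones" check amounts to. It is not the XLA code: it assumes the default layout is dense row-major, whereas the real function asks the device for its layout via GetDefaultLayout. The function names are made up for illustration. DLPack counts strides in elements rather than bytes, and a missing (NULL) strides field means a compact row-major tensor.

```python
def default_strides(shape):
    """Dense row-major strides for `shape`, counted in elements."""
    strides = []
    step = 1
    # Walk from the innermost dimension outwards, accumulating the extent.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def has_default_layout(shape, strides):
    """Return True if `strides` describes a dense row-major tensor."""
    if strides is None:
        # DLPack allows strides to be omitted for compact row-major data.
        return True
    if len(strides) != len(shape):
        return False
    expected = default_strides(shape)
    # The stride of a size-1 dimension never affects addressing, so skip it.
    return all(
        dim == 1 or got == want
        for dim, got, want in zip(shape, strides, expected)
    )


if __name__ == "__main__":
    print(default_strides((2, 3, 4)))
    print(has_default_layout((2, 3, 4), (12, 4, 1)))
    print(has_default_layout((2, 3, 4), (1, 2, 6)))
```

Under that row-major assumption, a check like this would not need to ask the backend for anything, which is why the unconditional GetDefaultLayout call is the part that trips over the unimplemented C API path.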
the TODO is listed to @skye so forgive me for @'ing you.