kmheckel / spyx

Spyx: Spiking Neural Networks in JAX
https://spyx.readthedocs.io/en/latest/
MIT License

LLVM ERROR: mma16816 data type not supported #31

Closed neworderofjamie closed 2 weeks ago

neworderofjamie commented 1 month ago

I'm trying to run the SHD_jax.py model from the Open Neuromorphic blog and get this error. GPU is an Nvidia RTX A5000, running Driver Version: 555.42.06 with CUDA Version: 12.5. I installed JAX with pip install -U "jax[cuda12]" and spyx with pip install "spyx[data]".
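
In case more detail helps with reproducing, recent JAX versions can print a full environment report (jax/jaxlib versions plus nvidia-smi output) with:

    python -c "import jax; jax.print_environment_info()"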

kmheckel commented 1 month ago

Hi Jamie, I'll try to look into this in the next couple of days and get back to you; I'm currently at a conference.

neworderofjamie commented 1 month ago

Wonderful, let me know if you need any more information! Enjoy your conference.

neworderofjamie commented 1 month ago

It seems like mma16816 refers to a particular shape of "matrix fragment" (https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-shape) that gets sent to the tensor cores. I don't know very much about this, but the problem seems to occur when XLA compiles the scan (logged with JAX_LOG_COMPILES=1):

Finished tracing + transforming equal for pjit in 0.0010695457458496094 sec
Finished tracing + transforming _broadcast_arrays for pjit in 0.0008301734924316406 sec
Finished tracing + transforming _where for pjit in 0.002557516098022461 sec
Finished tracing + transforming remainder for pjit in 0.006844043731689453 sec
Finished tracing + transforming not_equal for pjit in 0.0008444786071777344 sec
Finished tracing + transforming bitwise_and for pjit in 0.0007028579711914062 sec
Finished tracing + transforming left_shift for pjit in 0.0010814666748046875 sec
Finished tracing + transforming swapaxes for pjit in 0.0005633831024169922 sec
Finished tracing + transforming bitwise_and for pjit in 0.0006837844848632812 sec
Finished tracing + transforming greater for pjit in 0.0008101463317871094 sec
Finished tracing + transforming swapaxes for pjit in 0.0004763603210449219 sec
Finished tracing + transforming unpackbits for pjit in 0.009809017181396484 sec
Finished tracing + transforming dot for pjit in 0.0012443065643310547 sec
Finished tracing + transforming wrapped_fun for pjit in 0.002126932144165039 sec
Finished tracing + transforming wrapped_fun for pjit in 0.002239704132080078 sec
Finished tracing + transforming _reduce_sum for pjit in 0.0004973411560058594 sec
Finished tracing + transforming equal for pjit in 0.0003476142883300781 sec
Finished tracing + transforming _one_hot for pjit in 0.0012297630310058594 sec
Finished tracing + transforming multiply for pjit in 0.0003788471221923828 sec
Finished tracing + transforming add for pjit in 0.00038743019104003906 sec
Finished tracing + transforming _reduce_max for pjit in 0.0007457733154296875 sec
Finished tracing + transforming subtract for pjit in 0.0003197193145751953 sec
Finished tracing + transforming exp for pjit in 0.0002810955047607422 sec
Finished tracing + transforming _reduce_sum for pjit in 0.0004992485046386719 sec
Finished tracing + transforming log for pjit in 0.0002684593200683594 sec
Finished tracing + transforming log_softmax for pjit in 0.004455089569091797 sec
Finished tracing + transforming multiply for pjit in 0.0003139972686767578 sec
Finished tracing + transforming _reduce_sum for pjit in 0.0004935264587402344 sec
Finished tracing + transforming negative for pjit in 0.00024700164794921875 sec
Finished tracing + transforming _reduce_sum for pjit in 0.00037789344787597656 sec
Finished tracing + transforming _mean for pjit in 0.0012240409851074219 sec
Finished tracing + transforming net_eval for pjit in 1.7025067806243896 sec
Finished tracing + transforming multiply for pjit in 0.0003402233123779297 sec
Finished tracing + transforming add for pjit in 0.00036215782165527344 sec
Finished tracing + transforming true_divide for pjit in 0.0003590583801269531 sec
Finished tracing + transforming multiply for pjit in 0.00028133392333984375 sec
Finished tracing + transforming multiply for pjit in 0.0004260540008544922 sec
Finished tracing + transforming add for pjit in 0.00032019615173339844 sec
Finished tracing + transforming multiply for pjit in 0.0003294944763183594 sec
Finished tracing + transforming add for pjit in 0.0002956390380859375 sec
Finished tracing + transforming multiply for pjit in 0.00032901763916015625 sec
Finished tracing + transforming add for pjit in 0.00039267539978027344 sec
Finished tracing + transforming multiply for pjit in 0.00033593177795410156 sec
Finished tracing + transforming add for pjit in 0.0002865791320800781 sec
Finished tracing + transforming less for pjit in 0.00032782554626464844 sec
Finished tracing + transforming add for pjit in 0.0002799034118652344 sec
Finished tracing + transforming _where for pjit in 0.0005338191986083984 sec
Finished tracing + transforming _power for pjit in 0.0004990100860595703 sec
Finished tracing + transforming subtract for pjit in 0.0003542900085449219 sec
Finished tracing + transforming true_divide for pjit in 0.000286102294921875 sec
Finished tracing + transforming true_divide for pjit in 0.0002815723419189453 sec
Finished tracing + transforming true_divide for pjit in 0.0002815723419189453 sec
Finished tracing + transforming true_divide for pjit in 0.0002837181091308594 sec
Finished tracing + transforming tree_bias_correction for pjit in 0.005155086517333984 sec
Finished tracing + transforming add for pjit in 0.00044655799865722656 sec
Finished tracing + transforming sqrt for pjit in 0.00024175643920898438 sec
Finished tracing + transforming true_divide for pjit in 0.00029349327087402344 sec
Finished tracing + transforming add for pjit in 0.0003306865692138672 sec
Finished tracing + transforming sqrt for pjit in 0.00024080276489257812 sec
Finished tracing + transforming true_divide for pjit in 0.0002849102020263672 sec
Finished tracing + transforming add for pjit in 0.00043964385986328125 sec
Finished tracing + transforming sqrt for pjit in 0.0002624988555908203 sec
Finished tracing + transforming true_divide for pjit in 0.0002868175506591797 sec
Finished tracing + transforming add for pjit in 0.0003349781036376953 sec
Finished tracing + transforming sqrt for pjit in 0.00024008750915527344 sec
Finished tracing + transforming true_divide for pjit in 0.0002853870391845703 sec
Finished tracing + transforming train_step for pjit in 7.95122504234314 sec
Finished tracing + transforming _reduce_sum for pjit in 0.0004010200500488281 sec
Finished tracing + transforming _mean for pjit in 0.0012803077697753906 sec
Finished tracing + transforming scan for pjit in 0.00051116943359375 sec
Compiling scan with global shapes and types [ShapedArray(uint32[2]), ShapedArray(uint8[8156,32,128]), ShapedArray(uint8[8156]), ShapedArray(float32[20]), ShapedArray(float32[128]), ShapedArray(float32[128]), ShapedArray(float32[128,128]), ShapedArray(float32[128,128]), ShapedArray(float32[128,20]), ShapedArray(int32[]), ShapedArray(float32[20]), ShapedArray(float32[128]), ShapedArray(float32[128]), ShapedArray(float32[128,128]), ShapedArray(float32[128,128]), ShapedArray(float32[128,20]), ShapedArray(float32[20]), ShapedArray(float32[128]), ShapedArray(float32[128]), ShapedArray(float32[128,128]), ShapedArray(float32[128,128]), ShapedArray(float32[128,20]), ShapedArray(int32[100])]. Argument mapping: [UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue].
Finished tracing + transforming less for pjit in 0.0004451274871826172 sec
Finished tracing + transforming add for pjit in 0.0003237724304199219 sec
Finished tracing + transforming add for pjit in 0.0003662109375 sec
Finished tracing + transforming bitwise_and for pjit in 0.0002689361572265625 sec
Finished tracing + transforming _threefry_seed for pjit in 0.0011444091796875 sec
Finished tracing + transforming ravel for pjit in 0.00017714500427246094 sec
Finished tracing + transforming threefry_2x32 for pjit in 0.0013267993927001953 sec
Finished tracing + transforming _threefry_fold_in for pjit in 0.0035865306854248047 sec
Finished tracing + transforming add for pjit in 0.00028014183044433594 sec
Finished tracing + transforming add for pjit in 0.0002779960632324219 sec
Finished tracing + transforming bitwise_or for pjit in 0.0002715587615966797 sec
Finished tracing + transforming bitwise_xor for pjit in 0.00040149688720703125 sec
Finished jaxpr to MLIR module conversion jit(scan) in 3.7547388076782227 sec
LLVM ERROR: mma16816 data type not supported
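
I haven't confirmed this helps, but if the crash comes from XLA's Triton-based GEMM lowering (which is where these mma instructions are generated on recent versions), it might be possible to sidestep it by disabling that path via an XLA flag:

    XLA_FLAGS=--xla_gpu_enable_triton_gemm=false python SHD_jax.py
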
kmheckel commented 1 month ago

Jamie,

I haven't sat down yet to work through this, but my initial reaction is that you may just need to use an older version of CUDA; this seems like an error deeper in the tech stack than Spyx.

The CUDA 12.2 or 12.3 Ubuntu dev containers from NVIDIA on Docker Hub work well out of the box, if you can use those.

neworderofjamie commented 1 month ago

Totally agree, it's happening deeper in the stack, perhaps in XLA. What version of JAX would you recommend on the 12.3 container? The current pip version requires 12.5 (hence my upgrading).

kmheckel commented 1 month ago

I think jaxlib 0.4.24 through 0.4.26 might work? 0.4.24 was the first to add CUDA 12.3 support, so it should work.

neworderofjamie commented 1 month ago

Any idea what an appropriate version of torch would be to work with the same versions of CUDA/cuDNN as JAX 0.4.26? Your SHD model uses torch data loaders, and I am really struggling to find versions of the CUDA packages, torch, and JAX that are all compatible. If you could pip freeze me a requirements.txt from a working environment, that would be amazing.

kmheckel commented 1 month ago

Sure thing Jamie, I'll try to find one this morning and add it to the repo. As a short-term workaround, you could set up an environment where just the torch data loaders work, load the data there, and pickle it before loading it in an environment where JAX works; this is definitely not ideal, though. It's been on the TODO list to find a good replacement for the torch data loaders to avoid this kind of issue.
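
Something like the following is what I have in mind (a rough sketch; the variable names are illustrative and assume the loaders yield data convertible to dense numpy arrays):

    # Environment where the torch/tonic loaders work: dump to disk.
    import pickle
    import numpy as np

    # x_train / y_train are illustrative names for whatever the
    # loading pipeline produces.
    with open("shd_train.pkl", "wb") as f:
        pickle.dump((np.asarray(x_train), np.asarray(y_train)), f)

    # Environment where JAX works: load without importing torch.
    import pickle
    import jax.numpy as jnp

    with open("shd_train.pkl", "rb") as f:
        x_train, y_train = pickle.load(f)
    x_train = jnp.array(x_train, dtype=jnp.uint8)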

kmheckel commented 1 month ago

Still working on this; another workaround would be to install the CPU-only version of PyTorch to avoid extra CUDA issues.

Still searching for a good CUDA setup to avoid the LLVM error. If I can't find something in the next hour or so, we should kick this error over to the main JAX repo, as LLVM crashes are high-priority bugs for them.

kmheckel commented 1 month ago

Here is a known working config:

CUDA 12.3

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip3 install "jax[cuda12]==0.4.26"
pip3 install spyx tonic
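
A quick sanity check that the GPU backend actually loads (nothing Spyx-specific, just confirming JAX sees the card):

    python3 -c "import jax; print(jax.__version__, jax.devices())"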

I'll push the freeze.txt file to the repo so you can pull it as well. The installation instructions need to be updated, as does the setup.py. Additionally, the install extra was changed to spyx[loaders], so I must have missed updating that note somewhere.

neworderofjamie commented 1 month ago

Wonderful, thank you! I should have thought of a CPU-only PyTorch install. Can you reproduce the LLVM crash? If so, you might be better placed to raise an issue; if not, I'm happy to do so.

neworderofjamie commented 1 month ago

When I try installing spyx[loaders], I also get:

The 'sklearn' PyPI package is deprecated, use 'scikit-learn'
    rather than 'sklearn' for pip commands.

neworderofjamie commented 1 month ago

And I'm afraid the example now fails with:

Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 605, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax_plugins/xla_cuda12/__init__.py", line 83, in initialize
    xla_client.register_custom_call_handler(
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jaxlib/xla_client.py", line 658, in register_custom_call_handler
    handler(name, fn, xla_platform_name, api_version, traits)
TypeError: register_custom_call_target(): incompatible function arguments. The following argument types are supported:
    1. register_custom_call_target(c_api: capsule, fn_name: str, fn: capsule, xla_platform_name: str, api_version: int = 0) -> None

Invoked with types: PyCapsule, str, PyCapsule, str, int, jaxlib.xla_client.CustomCallTargetTraits
Traceback (most recent call last):
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 879, in backends
    backend = _init_backend(platform)
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 970, in _init_backend
    backend = registration.factory()
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 662, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jaxlib/xla_client.py", line 177, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.54).

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/its/home/jk421/Documents/SHD_jax.py", line 71, in <module>
    x_train = jnp.array(x_train, dtype=jnp.uint8)
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3214, in array
    out_array: Array = lax_internal._convert_element_type(
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 559, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/its/home/jk421/virtualenv/spyx/lib/python3.9/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.54). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

kmheckel commented 1 month ago

When I try installing spyx[loaders], I also get:

The 'sklearn' PyPI package is deprecated, use 'scikit-learn'
    rather than 'sklearn' for pip commands.

This should be fixed now, as I just pushed a tweak to the setup.py file to require scikit-learn instead of sklearn and also changed the torchvision requirement to be the CPU version. I'll have to release a new version to PyPI once this is sorted.

If you could list the results of pip list | grep jax, that would help; it seems like jax-cuda12-pjrt may not be at 0.4.26?

The working_12_3.txt file in the main directory of the repo lists the environment I got working when using the nvidia/cuda:12.3.2-devel-ubuntu22.04 Docker container on an A4000 on Vast.ai.
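
(To reproduce that setup, something along the lines of docker run --gpus all -it nvidia/cuda:12.3.2-devel-ubuntu22.04 followed by the pip commands above should work; the exact flags may vary with your Docker/NVIDIA container toolkit setup.)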

neworderofjamie commented 1 month ago

Oh yeah, sorry: spyx requires "jax>=0.4.27", so installing it upgraded jax and jaxlib to a newer version while somehow leaving the CUDA12 plugin packages behind:

jax                      0.4.30    
jax-cuda12-pjrt          0.4.26    
jax-cuda12-plugin        0.4.26    
jaxlib                   0.4.30 

It's working now, though. Thanks for all your help.

kmheckel commented 1 month ago

Awesome. I'm swamped for the next two weeks, so I don't have the bandwidth to run down the LLVM crash issue, but I'm glad this works in the meantime. If there's anything else experiments-wise you want to discuss, feel free to shoot me an email; I know the Spyx benchmarks need to be updated across the board.