TolimanSpace / toliman-flake


jaxlib on jetson #2

Open christhechris opened 4 months ago

christhechris commented 4 months ago

Useful links

- https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/index.html#from-network-repositories-or-local-installers
- https://github.com/google/jax
- https://forums.developer.nvidia.com/t/jax-on-jetson-nano/182593

notes

christhechris commented 4 months ago

Working through some errors: https://chat.openai.com/share/97c3be3b-8753-4c48-9c8c-68f684a69ff0

christhechris commented 4 months ago

Proper build command: `python3 build/build.py --enable_cuda --cuda_compute_capabilities compute_72 --cuda_path=$CUDA_HOME --cudnn_path=$CUDNN_HOME`

The build finally fails with a CUDA version error: `jaxlib/cuda/versions_helpers.cc:27:2: error: #error "JAX requires CUDA 11.8 or newer."`
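One way to catch this before starting a long Bazel build is to check the toolkit version first. A minimal sketch, assuming `nvcc --version` prints a line of the form `Cuda compilation tools, release 12.2, V12.2.140`; the regex and the helper names are illustrative and not part of JAX's build script:

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Minimum CUDA toolkit version, taken from the jaxlib build error above.
MIN_CUDA = (11, 8)


def parse_nvcc_release(output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from `nvcc --version` output (assumed format)."""
    match = re.search(r"release (\d+)\.(\d+)", output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def cuda_is_new_enough(version: Tuple[int, int],
                       minimum: Tuple[int, int] = MIN_CUDA) -> bool:
    # Tuples compare element-wise: (12, 2) >= (11, 8) but (11, 4) < (11, 8).
    return version >= minimum


def check_local_nvcc() -> None:
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        print("nvcc not found on PATH; is $CUDA_HOME/bin on PATH?")
        return
    out = subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout
    version = parse_nvcc_release(out)
    if version is None:
        print("could not parse nvcc output:\n" + out)
    elif cuda_is_new_enough(version):
        print(f"CUDA {version[0]}.{version[1]} meets the 11.8 minimum")
    else:
        print(f"CUDA {version[0]}.{version[1]} is too old; jaxlib needs 11.8+")


if __name__ == "__main__":
    check_local_nvcc()
```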

christhechris commented 4 months ago

Nope: according to the JAX changelog (https://github.com/google/jax/blob/main/CHANGELOG.md#jax-048-march-29-2023), CUDA 11.4 support was dropped earlier than the version flux requires (JAX 0.4.13+).

arduano commented 4 months ago

Some things to note:

The way Anduril packaged CUDA is very non-standard relative to the way nixpkgs does it, so many assumptions are broken, and I had to make a big hack to fix it in the shared overlay file. This is because Nvidia doesn't actually provide a portable SDK, let alone one you can compile; instead, it just provides some raw .deb files for the Jetson.

Also, one of the best ways of testing whether the environment is applied correctly is to use the jax-test Python executable that Connor and I made, and run commands inside it to query the environment. Nix generally doesn't have global state (except for system-level things like the kernel and GPU drivers), so it's difficult to know what's going on inside a derivation without getting the code running in the derivation to tell you.
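As an illustration of that approach, here is a sketch of what one might run inside the `jax-test` interpreter to report what the derivation actually sees. `CUDA_HOME` and `CUDNN_HOME` come from the build command in this thread; the other variable names and the `describe_environment` helper are assumptions. The JAX calls are wrapped so the script still produces a report if JAX is missing or its backend fails to initialise:

```python
import os
import sys

# CUDA_HOME/CUDNN_HOME come from the build command in this thread;
# the remaining variables are common guesses worth inspecting.
ENV_VARS = ("CUDA_HOME", "CUDNN_HOME", "LD_LIBRARY_PATH", "XLA_FLAGS", "PATH")


def describe_environment() -> dict:
    info = {
        "python": sys.executable,
        "python_version": sys.version.split()[0],
        "env": {name: os.environ.get(name) for name in ENV_VARS},
    }
    try:
        import jax  # imported lazily so the report works without JAX installed
        info["jax_version"] = jax.__version__
        info["backend"] = jax.default_backend()
        info["devices"] = [str(d) for d in jax.devices()]
    except Exception as exc:  # ImportError or a backend initialisation failure
        info["jax_error"] = repr(exc)
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If `backend` reports `cpu` inside the derivation, that suggests the CUDA libraries are not visible there, even if they work globally.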

However, things like nvidia-smi should be functional globally, so if they're not then that could be a symptom of other issues too.
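A quick way to see which of these tools resolve outside a derivation, sketched with only the standard library. The tool list is an assumption; `tegrastats` is included because it is the usual monitoring tool on Jetson boards, and it is only located, never run:

```python
import shutil
import subprocess
from typing import Dict, Iterable, Optional

# Tools expected on the global PATH if the driver/toolkit install is healthy.
# This list is an assumption; adjust it for the board in question.
TOOLS = ("nvidia-smi", "tegrastats", "nvcc")


def locate_tools(tools: Iterable[str] = TOOLS) -> Dict[str, Optional[str]]:
    """Map each tool name to its absolute path, or None if it is not on PATH."""
    return {tool: shutil.which(tool) for tool in tools}


def nvidia_smi_summary() -> Optional[str]:
    """Return nvidia-smi's output, or None if it is missing or exits non-zero."""
    path = shutil.which("nvidia-smi")
    if path is None:
        return None
    result = subprocess.run([path], capture_output=True, text=True, timeout=30)
    return result.stdout if result.returncode == 0 else None


if __name__ == "__main__":
    for tool, path in locate_tools().items():
        print(f"{tool:12s} {path or 'NOT FOUND'}")
    summary = nvidia_smi_summary()
    print(summary if summary is not None else "nvidia-smi unavailable or failed")
```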

christhechris commented 4 months ago

Can CUDA be made part of the global state?

arduano commented 4 months ago

Yes, but that technically goes against Nix principles; still, if it works, it works. I'm not sure how to make CUDA part of the global state; maybe you need to set some environment variables? That shouldn't be too hard if that's the case: https://nixos.wiki/wiki/Environment_variables

christhechris commented 4 months ago

All good. For the moment I'm going to focus on getting JAX working on a normal JetPack 5.1.3 (L4T 35.1) system, get wheels built, etc., and go from there.

christhechris commented 4 months ago

Finally found the correct way to update CUDA: https://developer.nvidia.com/cuda-12-2-0-download-archive?target_os=Linux&target_arch=aarch64-jetson&Compilation=Native&Distribution=Ubuntu&target_version=20.04&target_type=deb_network

CLangford2098 commented 4 months ago

Encountering a new issue now, similar to this one: https://github.com/google/jax/issues/5723.

It seems like a driver and CUDA version incompatibility. Will investigate further.

CUDA version: 12.2.140; driver version: 35.5.0 (from `grep "X Driver" /var/log/Xorg.0.log`)
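On Jetson, the "driver version" reported here (35.5.0) appears to be the L4T release number rather than a desktop driver branch number, so it may be clearer to read the L4T release directly. A sketch, assuming `/etc/nv_tegra_release` exists and begins with a line of the form `# R35 (release), REVISION: 5.0, ...`; both the path and the format are assumptions to verify on the device, and the helper names are hypothetical:

```python
import re
from typing import Optional

# Assumed location and first-line format of the L4T release file.
TEGRA_RELEASE_FILE = "/etc/nv_tegra_release"
_PATTERN = re.compile(r"#\s*R(\d+)\s*\(release\),\s*REVISION:\s*(\d+)\.(\d+)")


def parse_l4t_release(text: str) -> Optional[str]:
    """Turn '# R35 (release), REVISION: 5.0, ...' into '35.5.0'; None if no match."""
    match = _PATTERN.search(text)
    if match is None:
        return None
    major, minor, patch = match.groups()
    return f"{major}.{minor}.{patch}"


def read_l4t_release(path: str = TEGRA_RELEASE_FILE) -> Optional[str]:
    try:
        with open(path, encoding="utf-8") as handle:
            return parse_l4t_release(handle.read())
    except OSError:
        return None  # not a Jetson, or the file lives elsewhere


if __name__ == "__main__":
    print("L4T release:", read_l4t_release() or "unknown")
```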

christhechris commented 4 months ago

Can you share your error?


(Replying by email to Connor Langford's comment of Monday, March 18, 2024 11:52:18 AM.)


CLangford2098 commented 4 months ago

```
E0318 11:57:21.693717 5506 pjrt_stream_executor_client.cc:2813] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: no kernel image is available for execution on the device
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "", line 1, in
  File "/home/toliman/jax/jax/_src/random.py", line 711, in normal
    return _normal(key, shape, dtype)  # type: ignore
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: no kernel image is available for execution on the device
```

This happens when running the "Multiplying Matrices" code from the JAX quickstart: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
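To make reports like this easier to reproduce and compare, one could run a quickstart-style matrix multiply in a fresh interpreter and map known failure strings to the fixes that worked in this thread. A sketch; the `KNOWN_ERRORS` table and the helper names are hypothetical, and the hints only restate what resolved these two errors here, not a general diagnosis:

```python
import subprocess
import sys
from typing import Optional, Tuple

# Quickstart-style smoke test: random.normal followed by a matrix multiply.
SMOKE_TEST = """
import jax, jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (3000, 3000), dtype=jnp.float32)
y = jnp.dot(x, x.T).block_until_ready()
print(jax.default_backend(), y.shape)
"""

# Failure substrings seen in this thread, mapped to what resolved them here.
KNOWN_ERRORS = {
    "no kernel image is available for execution on the device":
        "jaxlib not built for this GPU; rebuild with --cuda_compute_capabilities=sm_72",
    "JAX requires CUDA 11.8 or newer":
        "CUDA toolkit too old for this jaxlib; install CUDA 12.x",
}


def classify_error(message: str) -> Optional[str]:
    """Return the first matching hint, or None for an unrecognised error."""
    for needle, hint in KNOWN_ERRORS.items():
        if needle in message:
            return hint
    return None


def run_smoke_test(timeout: float = 600) -> Tuple[bool, str]:
    """Run SMOKE_TEST in a child interpreter; return (passed, stdout or hint)."""
    proc = subprocess.run([sys.executable, "-c", SMOKE_TEST],
                          capture_output=True, text=True, timeout=timeout)
    if proc.returncode == 0:
        return True, proc.stdout.strip()
    return False, classify_error(proc.stderr) or proc.stderr.strip()[-2000:]


if __name__ == "__main__":
    passed, detail = run_smoke_test()
    print("PASS" if passed else "FAIL", detail)
```

Running the test in a child process keeps a crashing CUDA context from taking down the calling interpreter.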

CLangford2098 commented 4 months ago

I've just added the versions to my initial comment. I'm concerned about the driver version: "We recommend installing the newest driver available from NVIDIA, but the driver must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux." I tried installing driver version 535, but I don't think it's being used (and I suspect the installation didn't work).

christhechris commented 4 months ago

The fix was to specify the compute capability: `python3 build/build.py --enable_cuda --cuda_compute_capabilities=sm_72`
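For reuse in a script or a derivation's build phase, the working invocation could be assembled programmatically. A minimal sketch: `jaxlib_build_argv` is a hypothetical helper, the flags are the ones used in this thread (the `--cuda_path`/`--cudnn_path` options from the earlier command are appended only when given), and these flag names match the `build.py` used here but may differ in other JAX releases. The Jetson-family table (Xavier 7.2, Orin 8.7) is added for convenience:

```python
import shlex
from typing import List, Optional

# Compute capabilities of common Jetson families (added for convenience;
# check your board). The board in this thread needed sm_72.
JETSON_CAPABILITIES = {"xavier": "sm_72", "orin": "sm_87"}


def jaxlib_build_argv(capability: str,
                      cuda_path: Optional[str] = None,
                      cudnn_path: Optional[str] = None) -> List[str]:
    """Build the argv for JAX's build/build.py with CUDA enabled."""
    argv = ["python3", "build/build.py", "--enable_cuda",
            f"--cuda_compute_capabilities={capability}"]
    if cuda_path:
        argv.append(f"--cuda_path={cuda_path}")
    if cudnn_path:
        argv.append(f"--cudnn_path={cudnn_path}")
    return argv


if __name__ == "__main__":
    argv = jaxlib_build_argv(JETSON_CAPABILITIES["xavier"])
    # Print a copy-pasteable shell command; run it from a JAX source checkout.
    print(shlex.join(argv))
```

Printed as-is, this reproduces the working command above; passing the argv list to `subprocess.run` from the JAX checkout would run the build directly.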