yuehaowang opened this issue 9 months ago
I encountered this error:
ValueError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: CaptureGpuGraph failed (the requested functionality is not supported; current tracing scope: custom-call.119): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current profiling annotation: XlaModule:#hlo_module=jit_body_fn,program_id=214#.
This was also solved by downgrading JAX to 0.4.16 (thanks @yuehaowang). I'm copy-pasting the error here so it's Google-able.
I was running RawNeRF with the latest JAX (0.4.18), but after about 300 training iterations I hit the error message quoted above.
After downgrading JAX from 0.4.18 to 0.4.16, the error went away.
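Here is a minimal sketch, assuming Python and the standard-library `importlib.metadata`, of how a training script could warn when it runs on the JAX version reported to fail in this thread. The helper names are hypothetical. Only the two versions mentioned here are encoded, and every other version is reported as untested.

```python
import importlib.metadata
from typing import Optional

# Only the data points reported in this thread (assumption: no other versions were tested).
REPORTED_FAILING = {"0.4.18"}
REPORTED_WORKING = {"0.4.16"}


def classify_jax_version(version: str) -> str:
    """Classify a JAX version string against the versions reported in this thread."""
    if version in REPORTED_FAILING:
        return "reported failing (CaptureGpuGraph error); consider jax==0.4.16"
    if version in REPORTED_WORKING:
        return "reported working"
    return "untested in this thread"


def installed_jax_version() -> Optional[str]:
    """Return the installed jax version, or None if jax is not installed."""
    try:
        return importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_jax_version()
    if version is None:
        print("jax is not installed")
    else:
        print(f"jax {version}: {classify_jax_version(version)}")
```

This check only reports a status and does not change the installed version. Pinning `jax==0.4.16` (and a matching `jaxlib`) in the project's requirements is what applies the downgrade itself.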
I was using CUDA 11.8 and installed JAX via `jax[cuda11_local]`. The installed packages were `jax` v0.4.18 and `jaxlib` 0.4.18+cuda11.cudnn86. I'm not sure whether this is caused by conflicts with other packages.
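To collect the same environment details for a report like this one, a small sketch along these lines may help. It reads the installed `jax` and `jaxlib` versions with the standard-library `importlib.metadata` and splits a jaxlib local version tag such as `0.4.18+cuda11.cudnn86` into CUDA and cuDNN parts. The helper names are hypothetical, and the tag layout is assumed from the version string reported here.

```python
import importlib.metadata
from typing import Dict, Iterable, Optional


def parse_jaxlib_version(version: str) -> Dict[str, str]:
    """Split a jaxlib version like '0.4.18+cuda11.cudnn86' into base, CUDA and cuDNN tags.

    Assumes the local tag after '+' lists 'cudaNN' and 'cudnnNN' parts separated by '.'.
    """
    base, _, local = version.partition("+")
    info = {"base": base}
    tags = local.split(".") if local else []
    for tag in tags:
        if tag.startswith("cudnn"):
            info["cudnn"] = tag[len("cudnn"):]
        elif tag.startswith("cuda"):
            info["cuda"] = tag[len("cuda"):]
    return info


def report_versions(packages: Iterable[str] = ("jax", "jaxlib")) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    versions = report_versions()
    for name, version in versions.items():
        print(f"{name}: {version or 'not installed'}")
    jaxlib_version = versions.get("jaxlib")
    if jaxlib_version:
        # Shows which CUDA/cuDNN build of jaxlib is installed, if the tag is present.
        print(parse_jaxlib_version(jaxlib_version))
```

For the packages listed above, `parse_jaxlib_version("0.4.18+cuda11.cudnn86")` would give `{"base": "0.4.18", "cuda": "11", "cudnn": "86"}`.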