google-research / multinerf

A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF
Apache License 2.0

Failed to run with Jax 0.4.18 #139

Open yuehaowang opened 9 months ago

yuehaowang commented 9 months ago

I was running RawNeRF with the latest JAX (0.4.18), but training failed with the error below after roughly 300 iterations:

INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: Failed to update gpu graph: Graph update result=kNodeTypeChanged: Failed to update CUDA graph: CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE: the graph update was not performed because it included changes which violated constraints specific to instantiated graph update; current profiling annotation: XlaModule:#hlo_module=pmap_train_step,program_id=237#.

After downgrading JAX from 0.4.18 to 0.4.16, the error was gone.

I was using CUDA 11.8 and installed JAX via jax[cuda11_local]. The installed packages were jax 0.4.18 and jaxlib 0.4.18+cuda11.cudnn86. I'm not sure whether this is caused by a conflict with other packages.
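
For anyone who wants to check their environment before a long training run, here is a minimal sketch that compares the installed JAX version against the two versions reported in this issue. The helper names (`parse_version`, `classify_jax_version`) and the labels are hypothetical and not part of multinerf or JAX. Only 0.4.18 (failing) and 0.4.16 (working) are known from this report; every other version is treated as unknown.

```python
import re

# Versions taken from this issue report only; nothing else has been verified.
REPORTED_WORKING = (0, 4, 16)
REPORTED_FAILING = (0, 4, 18)


def parse_version(version: str) -> tuple:
    """Return the leading (major, minor, patch) integers of a version string.

    Local suffixes such as "+cuda11.cudnn86" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def classify_jax_version(version: str) -> str:
    """Hypothetical helper: label a version using the reports in this issue."""
    parsed = parse_version(version)
    if parsed == REPORTED_FAILING:
        return "reported-failing"
    if parsed == REPORTED_WORKING:
        return "reported-working"
    return "unknown"


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("jax is not installed")
    else:
        print(jax.__version__, classify_jax_version(jax.__version__))
```

A result of "unknown" says nothing either way. It only means the version is not one of the two mentioned here.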

deoxyribose commented 8 months ago

I encountered this error:

ValueError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: CaptureGpuGraph failed (the requested functionality is not supported; current tracing scope: custom-call.119): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current profiling annotation: XlaModule:#hlo_module=jit_body_fn,program_id=214#.

which was also solved by downgrading to 0.4.16 (thanks @yuehaowang). Copy-pasting it here so it's googleable.