PredictiveIntelligenceLab / cvit

MIT License

How to solve Jax error #1

Open ramdhan1989 opened 1 week ago

ramdhan1989 commented 1 week ago

Hi, I have a problem with my JAX installation. I am running on a cluster, and this is the error:

```
Building wheels for collected packages: cvit
  Building wheel for cvit (setup.py): started
  Building wheel for cvit (setup.py): finished with status 'done'
  Created wheel for cvit: filename=cvit-0.0.1-py3-none-any.whl size=8459 sha256=c99748fe77dff6df3699504a7d06ddbda106cbe4351d5c2dfc1f93084f13c445
  Stored in directory: /tmp/SLURM_27474459/pip-ephem-wheel-cache-n1yylbvy/wheels/50/76/1e/0b6ca6fb3ba3e7ab9519c0f80e80cca7b3f8d594b790a57c90
Successfully built cvit
Installing collected packages: cvit
  Attempting uninstall: cvit
    Found existing installation: cvit 0.0.1
    Uninstalling cvit-0.0.1:
      Successfully uninstalled cvit-0.0.1
Successfully installed cvit-0.0.1
/home1/rwibawa/.local/lib/python3.11/site-packages/absl/flags/_validators.py:254: UserWarning: Flag --config has a non-None default value; therefore, mark_flag_as_required will pass even if flag is not specified in the command line!
  mark_flag_as_required(flag_name, flag_values)
2024-11-14 19:38:14.382678: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Traceback (most recent call last):
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 896, in backends
    backend = _init_backend(platform)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 982, in _init_backend
    backend = registration.factory()
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/xla_bridge.py", line 674, in factory
    return xla_client.make_c_api_client(plugin_name, updated_options, None)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jaxlib/xla_client.py", line 200, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/scratch1/rwibawa/cvit/ns/main.py", line 29, in <module>
    app.run(main)
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/main.py", line 21, in main
    train.train_and_evaluate(FLAGS.config)
  File "/scratch1/rwibawa/cvit/ns/train.py", line 30, in train_and_evaluate
    state = create_train_state(config, model, tx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/src/utils.py", line 32, in create_train_state
    x = jnp.ones(config.x_dim)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5949, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1615, in full
    fill_value = _convert_element_type(fill_value, dtype, weak_type)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 587, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 2981, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```

Thank you
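The first traceback begins with `cuInit` failing with `CUDA_ERROR_NO_DEVICE`, i.e. the CUDA driver reported no GPU to this process. A quick check from inside the same SLURM job is whether `nvidia-smi -L` lists any GPU at all. This is a minimal sketch that assumes `nvidia-smi` is on `PATH` on the cluster's GPU nodes; the helper name `list_nvidia_gpus` is hypothetical.

```python
import shutil
import subprocess


def list_nvidia_gpus():
    """Return the GPU lines printed by `nvidia-smi -L`.

    Returns None if nvidia-smi is not on PATH (assumed to mean no NVIDIA
    driver tools on this node), and an empty list if nvidia-smi exits with
    an error, for example when it cannot reach the driver.
    """
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return None
    result = subprocess.run([exe, "-L"], capture_output=True, text=True)
    if result.returncode != 0:
        return []
    # Each device is reported on its own line, e.g. "GPU 0: <name> (UUID: GPU-...)".
    return [line for line in result.stdout.splitlines() if line.startswith("GPU ")]


if __name__ == "__main__":
    gpus = list_nvidia_gpus()
    if gpus is None:
        print("nvidia-smi not found on PATH")
    elif not gpus:
        print("nvidia-smi found, but it reports no GPUs")
    else:
        print("\n".join(gpus))
```

`nvidia-smi` talks to the driver through NVML, which does not apply `CUDA_VISIBLE_DEVICES`, so an empty result here would point at the node or the job allocation rather than at JAX itself.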

ramdhan1989 commented 1 day ago

Update: the error is now:

```
I1120 15:16:37.697383 140683125090112 xla_bridge.py:906] Unable to initialize backend 'cuda':
I1120 15:16:37.698054 140683125090112 xla_bridge.py:906] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1120 15:16:37.716493 140683125090112 xla_bridge.py:906] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W1120 15:16:37.717142 140683125090112 xla_bridge.py:948] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
I1120 15:17:10.983708 140683125090112 checkpoint_manager.py:557] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers=None, handler_registry=None
```
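In this second log, JAX warns that a CUDA-enabled `jaxlib` is not installed and falls back to CPU, and the ROCm backend fails with an `AttributeError` on `jaxlib.xla_extension`. Errors of that kind often indicate that `jax`, `jaxlib` and the GPU plugin packages come from mismatched releases, although nothing in this thread confirms that. A minimal sketch using only the standard library's `importlib.metadata` to report which of these distributions the interpreter can actually see; the plugin names `jax-cuda12-plugin` and `jax-cuda12-pjrt` are the ones used by recent pip releases of JAX for CUDA 12 and are an assumption about this setup.

```python
from importlib import metadata

# Distributions to inspect. The jax-cuda12-* names are an assumption based on
# recent pip-installed JAX releases for CUDA 12; adjust them for your install.
JAX_DISTRIBUTIONS = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names=JAX_DISTRIBUTIONS):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:20s} {version or 'not installed'}")
```

It should be run with the same interpreter that runs `main.py`: the first traceback loads packages from `/home1/rwibawa/.local/lib/python3.11/site-packages`, while the SLURM script loads `python/3.12.1` and installs `jaxlib` through conda.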

SLURM script:

```bash
#!/bin/bash
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --time=48:00:00
#SBATCH --account=xxxxx
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --mem=128G

module purge
module load gcc/9.2.0
module load python/3.12.1
module load conda
mamba init bash
source ~/.bashrc

cd /scratch1/rwibawa/cvit

python -m pip install --upgrade pip
python -m pip install --upgrade -r requirements.txt
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
python -m pip install wandb==0.17.2
python -m pip install absl_py

cd /scratch1/rwibawa/cvit/ns

CUDA_VISIBLE_DEVICES=[0] python main.py --config=configs/cvit_4x4.py
```
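The last line of the script sets `CUDA_VISIBLE_DEVICES=[0]`. The CUDA documentation describes this variable as a comma-separated list of GPU identifiers (integer indices or UUID strings) and states that if one of the indices is invalid, only the devices listed before it are visible. Under that rule `[0]` is not a valid index, so no GPU would be visible to the process, which would be consistent with the `CUDA_ERROR_NO_DEVICE` error in the first traceback. The helper `visible_indices` below is a hypothetical, simplified model of that documented rule for integer indices only; it is not CUDA's actual parser.

```python
import re


def visible_indices(value, device_count):
    """Model which physical GPU indices CUDA would expose for a given
    CUDA_VISIBLE_DEVICES value, following the documented rule that
    enumeration stops at the first invalid entry.

    Simplifications (assumptions of this sketch): only integer indices are
    modelled, and an entry counts as invalid if it is not a plain
    non-negative integer or is >= device_count.
    """
    visible = []
    for entry in value.split(","):
        if entry.startswith(("GPU-", "MIG-")):
            raise NotImplementedError("UUID entries are not modelled in this sketch")
        if not re.fullmatch(r"[0-9]+", entry) or int(entry) >= device_count:
            break  # this entry and everything after it are hidden
        visible.append(int(entry))
    return visible


if __name__ == "__main__":
    for value in ["[0]", "0", "0,2,-1,1"]:
        print(f"CUDA_VISIBLE_DEVICES={value!r:12} -> {visible_indices(value, device_count=3)}")
```

With one GPU allocated, `visible_indices("[0]", 1)` returns an empty list while `visible_indices("0", 1)` returns `[0]`, and the documentation's own example `0,2,-1,1` yields devices 0 and 2. For a single-GPU job the conventional form is `CUDA_VISIBLE_DEVICES=0`; on many SLURM installations `--gres=gpu:1` already sets the variable for the job.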