maisuiqianxun closed this issue 1 year ago
Here are the installed packages in the conda environment:
absl-py 1.4.0
aiohttp 3.8.5
aiosignal 1.3.1
antlr4-python3-runtime 4.9.3
appdirs 1.4.4
asttokens 2.4.0
astunparse 1.6.3
async-timeout 4.0.3
attrs 21.4.0
backcall 0.2.0
black 23.9.1
cachetools 5.3.1
celluloid 0.2.0
certifi 2023.7.22
cfgv 3.4.0
charset-normalizer 3.2.0
chex 0.1.5
click 8.1.7
cycler 0.11.0
datasets 2.14.5
decorator 5.1.1
dill 0.3.7
distlib 0.3.7
dm-tree 0.1.8
docker-pycreds 0.4.0
etils 1.4.1
exceptiongroup 1.1.3
executing 1.2.0
fastjsonschema 2.18.0
filelock 3.12.4
flake8 6.1.0
flatbuffers 2.0.7
flax 0.5.0
fonttools 4.42.1
frozenlist 1.4.0
fsspec 2023.6.0
gast 0.4.0
gitdb 4.0.10
GitPython 3.1.36
google-api-core 2.11.1
google-auth 2.23.0
google-auth-oauthlib 1.0.0
google-cloud-core 2.3.3
google-cloud-storage 2.10.0
google-crc32c 1.5.0
google-pasta 0.2.0
google-resumable-media 2.6.0
googleapis-common-protos 1.60.0
grpcio 1.58.0
h5py 3.9.0
huggingface-hub 0.17.1
hydra-core 1.3.2
identify 2.5.29
idna 3.4
importlib-metadata 6.8.0
importlib-resources 6.0.1
ipython 8.15.0
isort 5.12.0
jax 0.4.7
jaxlib 0.4.7+cuda11.cudnn82
jedi 0.19.0
jsonschema 4.17.3
jupyter_core 5.3.1
jupytext 1.13.4
keras 2.13.1
kiwisolver 1.4.5
libclang 16.0.6
Markdown 3.4.4
markdown-it-py 1.1.0
MarkupSafe 2.1.3
matplotlib 3.5.1
matplotlib-inline 0.1.6
mccabe 0.7.0
mdit-py-plugins 0.4.0
ml-dtypes 0.2.0
msgpack 1.0.5
multidict 6.0.4
multiprocess 0.70.15
mypy-extensions 1.0.0
nbformat 5.9.2
nodeenv 1.8.0
numpy 1.24.3
oauthlib 3.2.2
omegaconf 2.3.0
opt-einsum 3.3.0
optax 0.1.7
packaging 23.1
pandas 2.1.0
parso 0.8.3
pathspec 0.11.2
pathtools 0.1.2
pexpect 4.8.0
pickleshare 0.7.5
Pillow 10.0.1
pip 23.2.1
platformdirs 3.10.0
pre-commit 2.19.0
prompt-toolkit 3.0.39
protobuf 4.24.3
psutil 5.9.5
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 13.0.0
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycodestyle 2.11.0
pyflakes 3.1.0
Pygments 2.16.1
pyparsing 3.1.1
pyrsistent 0.19.3
python-dateutil 2.8.2
pytz 2023.3.post1
PyYAML 6.0.1
requests 2.31.0
requests-oauthlib 1.3.1
rsa 4.9
scipy 1.11.2
seaborn 0.12.2
sentry-sdk 1.31.0
setproctitle 1.3.2
setuptools 68.0.0
six 1.16.0
smmap 5.0.1
stack-data 0.6.2
tensorboard 2.13.0
tensorboard-data-server 0.7.1
tensorflow 2.13.0
tensorflow-estimator 2.13.0
tensorflow-io-gcs-filesystem 0.34.0
termcolor 2.3.0
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 1.10.1
torchaudio 0.10.1
torchfsdd 0.1.1
torchtext 0.11.1
torchvision 0.11.2
tqdm 4.66.1
traitlets 5.10.0
typing_extensions 4.5.0
tzdata 2023.3
urllib3 1.26.16
virtualenv 20.24.5
wandb 0.15.10
wcwidth 0.2.6
Werkzeug 2.3.7
wheel 0.38.4
wrapt 1.15.0
xxhash 3.3.0
yarl 1.9.2
zipp 3.16.2
I haven't seen this error. There might be an issue with the environment or library versions that I'm unable to help with.
@albertfgu Thanks for your reply. I have solved this issue after upgrading my cudnn version and re-installing the corresponding packages. Now it works.
Some information that may help others: Ubuntu 20.04; CUDA 11.3; cuDNN 8.9.4.25-1+cuda11.8; jaxlib 0.4.16+cuda11.cudnn86; jax 0.4.16; Python 3.9.
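For anyone comparing their own setup against the versions above, here is a minimal sketch that reads the CUDA/cuDNN tag out of a jaxlib version string and checks it against an installed cuDNN version. The helper names `parse_jaxlib_build` and `cudnn_satisfies` are hypothetical, and the rule it applies (same cuDNN major version, minor version at least as new as the one in the tag) is my assumption about what such wheels need, not something confirmed in this thread. It only handles the two-digit tag form seen here (`cudnn82`, `cudnn86`).

```python
import re
from typing import Optional, Tuple


def parse_jaxlib_build(version: str) -> Optional[Tuple[int, Tuple[int, int]]]:
    """Return (cuda_major, (cudnn_major, cudnn_minor)) from a jaxlib version.

    Only the two-digit tag form such as "+cuda11.cudnn86" is handled
    (assumption); anything else, including CPU-only builds, gives None.
    """
    match = re.search(r"\+cuda(\d+)\.cudnn(\d)(\d)$", version)
    if match is None:
        return None
    cuda_major = int(match.group(1))
    cudnn = (int(match.group(2)), int(match.group(3)))
    return cuda_major, cudnn


def cudnn_satisfies(installed: str, required: Tuple[int, int]) -> bool:
    """Check an installed cuDNN version string against a required (major, minor).

    Assumed rule: the major versions must match and the installed minor
    version must be at least the required one.
    """
    major, minor = (int(part) for part in installed.split(".")[:2])
    return major == required[0] and minor >= required[1]


if __name__ == "__main__":
    build = parse_jaxlib_build("0.4.16+cuda11.cudnn86")
    print("jaxlib build tag:", build)
    if build is not None:
        _, required_cudnn = build
        print("cuDNN OK:", cudnn_satisfies("8.9.4.25-1+cuda11.8", required_cudnn))
```

With the working combination reported above, the cuDNN package 8.9.4 passes this check against the `cudnn86` build tag.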
Has anyone seen the following error after installing the dependencies from the "requirements-gpu" file?
/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/flax/struct.py:133: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
jax.tree_util.register_keypaths(data_clz, keypaths)
/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/flax/struct.py:133: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
jax.tree_util.register_keypaths(data_clz, keypaths)
2023-09-18 22:31:54.050559: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/flax/struct.py:133: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
jax.tree_util.register_keypaths(data_clz, keypaths)
dataset: mnist
layer: s4
seed: 0
model:
  d_model: 128
  n_layers: 4
  dropout: 0.0
  prenorm: true
  embedding: false
  layer:
    'N': 64
train:
  epochs: 100
  bsz: 128
  lr: 0.001
  lr_schedule: false
  weight_decay: 0.01
  checkpoint: false
  suffix: null
  sample: null
wandb:
  mode: disabled
  project: s4
  entity: null

[*] Warning: models are not being checkpoint
[*] Setting Randomness...
[2023-09-18 22:31:55,075][jax._src.xla_bridge][INFO] - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Host Interpreter
[2023-09-18 22:31:55,075][jax._src.xla_bridge][INFO] - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[2023-09-18 22:31:55,075][jax._src.xla_bridge][INFO] - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2023-09-18 22:31:55.360848: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Error executing job with overrides: ['dataset=mnist', 'layer=s4', 'train.epochs=100', 'train.bsz=128', 'model.d_model=128', 'model.layer.N=64']
Traceback (most recent call last):
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/user/ykx/annotated-s4-main/s4/train.py", line 468, in <module>
main()
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "/home/user/ykx/annotated-s4-main/s4/train.py", line 464, in main
example_train(cfg)
File "/home/user/ykx/annotated-s4-main/s4/train.py", line 312, in example_train
key = jax.random.PRNGKey(seed)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 561, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 359, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 362, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 816, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
return seed(seeds)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 813, in threefry_seed
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 359, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 362, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 816, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
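The traceback above fails at `jax.random.PRNGKey(seed)`, the first call in `train.py` that runs a computation on the device, so the problem can probably be reproduced without the training script. Below is a minimal sketch of such a check; `gpu_smoke_test` is a hypothetical helper name, and it only uses `jax.default_backend()`, `jax.devices()` and `jax.random.PRNGKey`. It reports the failure as a string instead of raising, which makes it easy to run before and after changing the cuDNN or jaxlib version.

```python
def gpu_smoke_test(seed: int = 0) -> str:
    """Try the same first device computation that fails in the traceback.

    Returns a human-readable status string rather than raising, so it can be
    run in any environment (hypothetical helper, not part of annotated-s4).
    """
    try:
        import jax
    except ImportError:
        return "jax is not installed in this environment"

    try:
        backend = jax.default_backend()  # e.g. "gpu" or "cpu"
        key = jax.random.PRNGKey(seed)   # the call that failed in train.py
        key.block_until_ready()          # wait for the device to finish
    except Exception as err:  # XlaRuntimeError would be caught here
        return f"backend initialization failed: {type(err).__name__}: {err}"

    return f"ok: backend={backend}, devices={jax.devices()}"


if __name__ == "__main__":
    print(gpu_smoke_test())
```

If this prints a failure message in the same environment, the issue most likely lies in the CUDA/cuDNN/jaxlib combination rather than in the training code.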