srush / annotated-s4

Implementation of https://srush.github.io/annotated-s4
MIT License

DNN library initialization failed #75

Closed: maisuiqianxun closed this issue 1 year ago

maisuiqianxun commented 1 year ago

Has anyone else hit this error after installing the dependencies from the file "requirements-gpu"?

/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/flax/struct.py:133: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/flax/struct.py:133: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
2023-09-18 22:31:54.050559: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/flax/struct.py:133: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
dataset: mnist
layer: s4
seed: 0
model:
  d_model: 128
  n_layers: 4
  dropout: 0.0
  prenorm: true
  embedding: false
  layer:
    'N': 64
train:
  epochs: 100
  bsz: 128
  lr: 0.001
  lr_schedule: false
  weight_decay: 0.01
  checkpoint: false
  suffix: null
  sample: null
wandb:
  mode: disabled
  project: s4
  entity: null

[*] Warning: models are not being checkpoint
[*] Setting Randomness...
[2023-09-18 22:31:55,075][jax._src.xla_bridge][INFO] - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Host Interpreter
[2023-09-18 22:31:55,075][jax._src.xla_bridge][INFO] - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[2023-09-18 22:31:55,075][jax._src.xla_bridge][INFO] - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2023-09-18 22:31:55.360848: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Error executing job with overrides: ['dataset=mnist', 'layer=s4', 'train.epochs=100', 'train.bsz=128', 'model.d_model=128', 'model.layer.N=64']
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/user/ykx/annotated-s4-main/s4/train.py", line 468, in <module>
    main()
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/user/ykx/annotated-s4-main/s4/train.py", line 464, in main
    example_train(cfg)
  File "/home/user/ykx/annotated-s4-main/s4/train.py", line 312, in example_train
    key = jax.random.PRNGKey(seed)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 561, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 359, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 362, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 816, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
    return seed(seeds)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/prng.py", line 813, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 359, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 362, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/core.py", line 816, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, unsafe_map(arg_spec, args),
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/user/anaconda3/envs/JAX-GPU/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
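
The final exception only says to "look at the errors above", and the line that actually names the failure is the native `E ... cuda_dnn.cc:429]` entry logged before the traceback. A minimal sketch for pulling such lines out of a saved log, assuming one log record per line in the timestamped `LEVEL file:line] message` layout seen in this output (the helper name `native_errors` and the regex are illustrative, not part of JAX or this repo):

```python
import re

# Matches native log records such as
# "2023-09-18 22:31:55.360848: E external/.../cuda_dnn.cc:429] message".
# The level letter is assumed to be one of I, W, E, F.
LOG_LINE = re.compile(
    r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+: (?P<level>[IWEF]) "
    r"(?P<source>\S+:\d+)\] (?P<message>.*)"
)


def native_errors(log_text):
    """Return (source, message) pairs for every E- or F-level record."""
    found = []
    for line in log_text.splitlines():
        match = LOG_LINE.match(line.strip())
        if match and match.group("level") in ("E", "F"):
            found.append((match.group("source"), match.group("message")))
    return found


# Three lines copied from the output in this issue.
sample = """\
2023-09-18 22:31:54.050559: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2023-09-18 22:31:55.360848: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
"""

for source, message in native_errors(sample):
    print(source, "->", message)
```

On the sample above, only the `cuda_dnn.cc:429` record is reported; the TensorRT warning and the Python exception line are skipped.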

maisuiqianxun commented 1 year ago

Here are the packages installed in the conda environment:

absl-py 1.4.0
aiohttp 3.8.5
aiosignal 1.3.1
antlr4-python3-runtime 4.9.3
appdirs 1.4.4
asttokens 2.4.0
astunparse 1.6.3
async-timeout 4.0.3
attrs 21.4.0
backcall 0.2.0
black 23.9.1
cachetools 5.3.1
celluloid 0.2.0
certifi 2023.7.22
cfgv 3.4.0
charset-normalizer 3.2.0
chex 0.1.5
click 8.1.7
cycler 0.11.0
datasets 2.14.5
decorator 5.1.1
dill 0.3.7
distlib 0.3.7
dm-tree 0.1.8
docker-pycreds 0.4.0
etils 1.4.1
exceptiongroup 1.1.3
executing 1.2.0
fastjsonschema 2.18.0
filelock 3.12.4
flake8 6.1.0
flatbuffers 2.0.7
flax 0.5.0
fonttools 4.42.1
frozenlist 1.4.0
fsspec 2023.6.0
gast 0.4.0
gitdb 4.0.10
GitPython 3.1.36
google-api-core 2.11.1
google-auth 2.23.0
google-auth-oauthlib 1.0.0
google-cloud-core 2.3.3
google-cloud-storage 2.10.0
google-crc32c 1.5.0
google-pasta 0.2.0
google-resumable-media 2.6.0
googleapis-common-protos 1.60.0
grpcio 1.58.0
h5py 3.9.0
huggingface-hub 0.17.1
hydra-core 1.3.2
identify 2.5.29
idna 3.4
importlib-metadata 6.8.0
importlib-resources 6.0.1
ipython 8.15.0
isort 5.12.0
jax 0.4.7
jaxlib 0.4.7+cuda11.cudnn82
jedi 0.19.0
jsonschema 4.17.3
jupyter_core 5.3.1
jupytext 1.13.4
keras 2.13.1
kiwisolver 1.4.5
libclang 16.0.6
Markdown 3.4.4
markdown-it-py 1.1.0
MarkupSafe 2.1.3
matplotlib 3.5.1
matplotlib-inline 0.1.6
mccabe 0.7.0
mdit-py-plugins 0.4.0
ml-dtypes 0.2.0
msgpack 1.0.5
multidict 6.0.4
multiprocess 0.70.15
mypy-extensions 1.0.0
nbformat 5.9.2
nodeenv 1.8.0
numpy 1.24.3
oauthlib 3.2.2
omegaconf 2.3.0
opt-einsum 3.3.0
optax 0.1.7
packaging 23.1
pandas 2.1.0
parso 0.8.3
pathspec 0.11.2
pathtools 0.1.2
pexpect 4.8.0
pickleshare 0.7.5
Pillow 10.0.1
pip 23.2.1
platformdirs 3.10.0
pre-commit 2.19.0
prompt-toolkit 3.0.39
protobuf 4.24.3
psutil 5.9.5
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 13.0.0
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycodestyle 2.11.0
pyflakes 3.1.0
Pygments 2.16.1
pyparsing 3.1.1
pyrsistent 0.19.3
python-dateutil 2.8.2
pytz 2023.3.post1
PyYAML 6.0.1
requests 2.31.0
requests-oauthlib 1.3.1
rsa 4.9
scipy 1.11.2
seaborn 0.12.2
sentry-sdk 1.31.0
setproctitle 1.3.2
setuptools 68.0.0
six 1.16.0
smmap 5.0.1
stack-data 0.6.2
tensorboard 2.13.0
tensorboard-data-server 0.7.1
tensorflow 2.13.0
tensorflow-estimator 2.13.0
tensorflow-io-gcs-filesystem 0.34.0
termcolor 2.3.0
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 1.10.1
torchaudio 0.10.1
torchfsdd 0.1.1
torchtext 0.11.1
torchvision 0.11.2
tqdm 4.66.1
traitlets 5.10.0
typing_extensions 4.5.0
tzdata 2023.3
urllib3 1.26.16
virtualenv 20.24.5
wandb 0.15.10
wcwidth 0.2.6
Werkzeug 2.3.7
wheel 0.38.4
wrapt 1.15.0
xxhash 3.3.0
yarl 1.9.2
zipp 3.16.2

albertfgu commented 1 year ago

I haven't seen this error. There might be an issue with the environment or library versions that I'm unable to help with.

maisuiqianxun commented 1 year ago

@albertfgu Thanks for your reply. I solved this issue by upgrading my cuDNN version and reinstalling the corresponding packages. It works now.

maisuiqianxun commented 1 year ago

Some information that may help others. My working setup is: Ubuntu 20.04; CUDA 11.3; cuDNN 8.9.4.25-1+cuda11.8; jaxlib 0.4.16+cuda11.cudnn86; jax 0.4.16; Python 3.9.
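
For anyone comparing their own setup against the versions above, here is a rough sketch of a compatibility check. It assumes the `+cudaNN.cudnnXY` suffix on the jaxlib version names the CUDA major version and the minimum cuDNN major.minor the wheel expects (so `cudnn86` would mean cuDNN 8.6 or newer within the 8.x series). That reading of the tag, and the helper names `parse_jaxlib_tag` and `cudnn_satisfies`, are assumptions made for illustration, not an official JAX API.

```python
import re
from typing import NamedTuple, Optional, Tuple


class WheelBuild(NamedTuple):
    cuda_major: int
    cudnn: Tuple[int, int]  # (major, minor) taken from the wheel tag


def parse_jaxlib_tag(version: str) -> Optional[WheelBuild]:
    """Parse e.g. '0.4.16+cuda11.cudnn86'; return None for untagged versions.

    Only the two-digit 'cudnnXY' form seen in this thread is handled.
    """
    match = re.fullmatch(r"[^+]+\+cuda(\d+)\.cudnn(\d)(\d)", version)
    if match is None:
        return None
    return WheelBuild(int(match.group(1)), (int(match.group(2)), int(match.group(3))))


def cudnn_satisfies(installed: str, build: WheelBuild) -> bool:
    """True if the installed cuDNN has the same major and at least the tagged minor."""
    match = re.match(r"(\d+)\.(\d+)", installed)
    if match is None:
        raise ValueError(f"unrecognised cuDNN version: {installed!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return major == build.cudnn[0] and minor >= build.cudnn[1]


# Versions reported as working in this thread.
build = parse_jaxlib_tag("0.4.16+cuda11.cudnn86")
print(build)
print(cudnn_satisfies("8.9.4.25-1+cuda11.8", build))
```

The same check can be run with the failing environment's `jaxlib 0.4.7+cuda11.cudnn82` once the cuDNN version that was installed at the time is known; the thread does not record it.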