I am trying to test my installation of Jax-ReaxFF. Although I have a GPU, I do not have root access, so installing CUDA and cuDNN is quite cumbersome. I would therefore like to use the CPU version instead.
When I run the command for the benchmark test, however, it fails when starting the iteration process. The error seems to be thrown because the code still searches for CUDA, which is not installed. How can I override this behavior?
jaxreaxff --init_FF Datasets/cobalt/ffield_lit \
> --params Datasets/cobalt/params \
> --geo Datasets/cobalt/geo \
> --train_file Datasets/cobalt/trainset.in \
> --num_e_minim_steps 200 \
> --e_minim_LR 1e-3 \
> --out_folder ffields \
> --save_opt all \
> --num_trials 1 \
> --num_steps 20 \
> --init_FF_type fixed \
> --backend cpu
Selected backend for JAX: CPU
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[INFO] Force field field is read
[INFO] Parameter file is read, there are 12 parameters to be optimized!
[INFO] trainset file is read, there are 144 items
ENERGY:144
[INFO] Geometry file is read, there are 147 geometries and 130 require energy minimization!
After removing geometries that are not used in the trainset file:
[INFO] Geometry file is read, there are 146 geometries and 129 require energy minimization!
Multithreaded interaction list generation took 8.53 secs with 16 threads
Cost without aligning: 8697561.0
nonbounded: 8635439.0
bounded: 62122.0
Number of clusters: 6
Cost after aligning: 11500945.0
nonbounded: 11423457.0
bounded: 77488.0
****************************************
Trial-1 is starting...
****************************************
Iteration: 0
Traceback (most recent call last):
File "/work/dumortil/miniconda3/envs/jax-env/bin/jaxreaxff", line 33, in <module>
sys.exit(load_entry_point('jaxreaxff==0.1.0', 'console_scripts', 'jaxreaxff')())
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/driver.py", line 342, in main
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/optimizer.py", line 846, in train_FF
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/api.py", line 427, in cache_miss
out_flat = xla.xla_call(
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/core.py", line 1690, in bind
return call_bind(self, fun, *args, **params)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/core.py", line 1702, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/core.py", line 601, in process_call
return primitive.impl(f, *tracers, **params)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 142, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/linear_util.py", line 272, in memoized_fun
ans = call(fun, *args)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 212, in lower_xla_callable
device = _xla_callable_device(nreps, backend, device, arg_devices)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 362, in _xla_callable_device
return xb.get_backend(backend).get_default_device_assignment(1)[0]
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/lib/xla_bridge.py", line 301, in get_backend
return _get_backend_uncached(platform)
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/lib/xla_bridge.py", line 291, in _get_backend_uncached
raise RuntimeError(f"Backend '{platform}' failed to initialize: "
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Backend 'gpu' failed to initialize: NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/work/dumortil/miniconda3/envs/jax-env/bin/jaxreaxff", line 33, in <module>
sys.exit(load_entry_point('jaxreaxff==0.1.0', 'console_scripts', 'jaxreaxff')())
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/driver.py", line 342, in main
File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/optimizer.py", line 846, in train_FF
RuntimeError: Backend 'gpu' failed to initialize: NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
I've found the error.
In helper.py, the 'gpu' backend is hardcoded. Replacing it with 'cpu' and reinstalling solved my issue; however, this is inconsistent with the --backend cpu flag, which is silently ignored.
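For reference, the usual way to force JAX onto the CPU without touching package source is to set the JAX_PLATFORM_NAME environment variable before JAX is first imported (older releases such as the jax 0.3.1 shown in the traceback read it at backend initialization). A minimal sketch:

```python
import os

# Must run *before* jax is imported anywhere (directly or via jaxreaxff):
# JAX reads this variable once, when it initializes its backend.
os.environ["JAX_PLATFORM_NAME"] = "cpu"

# From here on, importing jax should select the CPU backend instead of
# probing for CUDA.
```

Equivalently, `export JAX_PLATFORM_NAME=cpu` in the shell before launching `jaxreaxff`. Note, though, that this may not be sufficient in this particular case: the traceback shows an explicit `backend` argument being passed down to `xla_call`, and an explicitly requested backend takes precedence over the default platform, which is why editing the hardcoded value in helper.py was needed.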