cagrikymk / JAX-ReaxFF

JAX-ReaxFF: A Gradient Based Framework for Extremely Fast Optimization of Reactive Force Fields
GNU General Public License v3.0

--backend cpu ignored during installation test #2

Closed lcdumort closed 2 years ago

lcdumort commented 2 years ago

I am trying to test my installation of JAX-ReaxFF. Although I have a GPU, I do not have root access, so installing CUDA and cuDNN is quite cumbersome. I would therefore like to use the CPU version instead. However, when I run the benchmark test command, it fails at the start of the iteration process. The error seems to be thrown because the code still searches for CUDA, which is not installed. How can I override this behavior?

jaxreaxff --init_FF Datasets/cobalt/ffield_lit             \
>           --params Datasets/cobalt/params                  \
>           --geo Datasets/cobalt/geo                        \
>           --train_file Datasets/cobalt/trainset.in         \
>           --num_e_minim_steps 200                          \
>           --e_minim_LR 1e-3                                \
>           --out_folder ffields                             \
>           --save_opt all                                   \
>           --num_trials 1                                   \
>           --num_steps 20                                   \
>           --init_FF_type fixed                             \
>           --backend cpu
Selected backend for JAX: CPU
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[INFO] Force field field is read
[INFO] Parameter file is read, there are 12 parameters to be optimized!
[INFO] trainset file is read, there are 144 items
ENERGY:144
[INFO] Geometry file is read, there are 147 geometries and 130 require energy minimization!
After removing geometries that are not used in the trainset file:
[INFO] Geometry file is read, there are 146 geometries and 129 require energy minimization!
Multithreaded interaction list generation took 8.53 secs with 16 threads
Cost without aligning: 8697561.0
nonbounded:            8635439.0
bounded:               62122.0
Number of clusters:    6
Cost after aligning:   11500945.0
nonbounded:            11423457.0
bounded:               77488.0
****************************************
Trial-1 is starting...
****************************************
Iteration: 0
Traceback (most recent call last):
  File "/work/dumortil/miniconda3/envs/jax-env/bin/jaxreaxff", line 33, in <module>
    sys.exit(load_entry_point('jaxreaxff==0.1.0', 'console_scripts', 'jaxreaxff')())
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/driver.py", line 342, in main
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/optimizer.py", line 846, in train_FF
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/api.py", line 427, in cache_miss
    out_flat = xla.xla_call(
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/core.py", line 1690, in bind
    return call_bind(self, fun, *args, **params)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/core.py", line 1702, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/core.py", line 601, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 142, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 212, in lower_xla_callable
    device = _xla_callable_device(nreps, backend, device, arg_devices)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/dispatch.py", line 362, in _xla_callable_device
    return xb.get_backend(backend).get_default_device_assignment(1)[0]
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/lib/xla_bridge.py", line 301, in get_backend
    return _get_backend_uncached(platform)
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jax-0.3.1-py3.8.egg/jax/_src/lib/xla_bridge.py", line 291, in _get_backend_uncached
    raise RuntimeError(f"Backend '{platform}' failed to initialize: "
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Backend 'gpu' failed to initialize: NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/work/dumortil/miniconda3/envs/jax-env/bin/jaxreaxff", line 33, in <module>
    sys.exit(load_entry_point('jaxreaxff==0.1.0', 'console_scripts', 'jaxreaxff')())
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/driver.py", line 342, in main
  File "/work/dumortil/miniconda3/envs/jax-env/lib/python3.8/site-packages/jaxreaxff-0.1.0-py3.8.egg/jaxreaxff/optimizer.py", line 846, in train_FF
RuntimeError: Backend 'gpu' failed to initialize: NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
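For context, JAX normally lets you select the platform before import via an environment variable. A minimal sketch (this alone does not resolve the issue reported here: it only changes JAX's *default* platform, and an explicit `backend='gpu'` passed to `jax.jit` still tries, and fails, to initialize the GPU backend):

```python
import os

# JAX reads JAX_PLATFORM_NAME at import time, so this must be set
# before the first `import jax` anywhere in the process.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
```

Equivalently, the same variable can be set in the shell when launching: `JAX_PLATFORM_NAME=cpu jaxreaxff ...`.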
lcdumort commented 2 years ago

I've found the error. In helper.py, the GPU backend is hardcoded. Replacing it with 'cpu' and reinstalling solved my issue; however, this is inconsistent with the --backend flag.
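The fix described above can be sketched as follows. The function and module names here are illustrative, not the actual jaxreaxff code; the underlying point is that `jax.jit` accepts a `backend` keyword, so instead of hardcoding `'gpu'` at the call site, the value parsed from `--backend` can be threaded through:

```python
# Illustrative sketch (hypothetical names): thread the CLI's --backend
# value into the jit call instead of hardcoding 'gpu'.

def make_energy_fn(compute_energy, backend="cpu"):
    """Wrap compute_energy for the user-selected backend.

    In the real code this would be jax.jit(compute_energy, backend=backend);
    the hardcoded 'gpu' in helper.py can simply be replaced by the value
    parsed from the --backend flag.
    """
    def wrapped(*args, **kwargs):
        # Stand-in for the jit-compiled call.
        return compute_energy(*args, **kwargs)

    wrapped.backend = backend  # record which backend was requested
    return wrapped

# The backend now follows the flag instead of being fixed at 'gpu'.
energy_fn = make_energy_fn(lambda x: x * 2, backend="cpu")
```

This keeps `--backend cpu` and the actual compilation target consistent, which is exactly the inconsistency the comment above points out.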