Hi Sergey! I was trying to run `af/examples/AF2Rank.ipynb` but ran into the following "DNN library initialization failed" error.
It happens when I run the first cell under "## rank structures":
NAME = "1mjc"
CHAIN = "A" # this can be multiple chains
NATIVE_PATH = f"{NAME}.pdb"
DECOY_DIR = f"{NAME}"
if save_output_pdbs:
os.makedirs(f"{NAME}_output",ok_exists=True)
# get data
%shell wget -qnc https://files.ipd.uw.edu/pub/decoyset/natives/{NAME}.pdb
%shell wget -qnc https://files.ipd.uw.edu/pub/decoyset/decoys/{NAME}.zip
%shell unzip -qqo {NAME}.zip
# setup model
clear_mem()
af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])
I get the following error:
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
[<ipython-input-4-5740c8be5964>](https://localhost:8080/#) in <cell line: 17>()
15 # setup model
16 clear_mem()
---> 17 af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])
21 frames
[<ipython-input-2-b0744074ff0e>](https://localhost:8080/#) in __init__(self, pdb, chain, model_name, model_names)
75 "model_name":model_name,
76 "model_names":model_names}
---> 77 self.reset()
78
79 def reset(self):
[<ipython-input-2-b0744074ff0e>](https://localhost:8080/#) in reset(self)
78
79 def reset(self):
---> 80 self.model = mk_af_model(protocol="fixbb",
81 use_templates=True,
82 use_multimer=self.args["use_multimer"],
[/content/colabdesign/af/model.py](https://localhost:8080/#) in __init__(self, protocol, use_multimer, use_templates, debug, data_dir, **kwargs)
118 self._model_params, self._model_names = [],[]
119 for model_name in model_names:
--> 120 params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir, fuse=True)
121 if params is not None:
122 if not self._args["use_multimer"] and not self._args["use_templates"]:
[/content/colabdesign/af/alphafold/model/data.py](https://localhost:8080/#) in get_model_haiku_params(model_name, data_dir, fuse)
39 with open(path, 'rb') as f:
40 params = np.load(io.BytesIO(f.read()), allow_pickle=False)
---> 41 return utils.flat_params_to_haiku(params, fuse=fuse)
[/content/colabdesign/af/alphafold/model/utils.py](https://localhost:8080/#) in flat_params_to_haiku(params, fuse)
108 P[f"{k}/{c}"] = {}
109 for d in ["bias","weights"]:
--> 110 P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1)
111 P[f"{k}/center_norm"] = P.pop(f"{k}/center_layer_norm")
112 P[f"{k}/left_norm_input"] = P.pop(f"{k}/layer_norm_input")
[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in concatenate(arrays, axis, dtype)
1852 k = 16
1853 while len(arrays_out) > 1:
-> 1854 arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
1855 for i in range(0, len(arrays_out), k)]
1856 return arrays_out[0]
[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in <listcomp>(.0)
1852 k = 16
1853 while len(arrays_out) > 1:
-> 1854 arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
1855 for i in range(0, len(arrays_out), k)]
1856 return arrays_out[0]
[/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py](https://localhost:8080/#) in concatenate(operands, dimension)
615 if isinstance(op, Array):
616 return type_cast(Array, op)
--> 617 return concatenate_p.bind(*operands, dimension=dimension)
618
619
[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in bind(self, *args, **params)
384 assert (not config.jax_enable_checks or
385 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 386 return self.bind_with_trace(find_top_trace(args), args, params)
387
388 def bind_with_trace(self, trace, args, params):
[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in bind_with_trace(self, trace, args, params)
387
388 def bind_with_trace(self, trace, args, params):
--> 389 out = trace.process_primitive(self, map(trace.full_raise, args), params)
390 return map(full_lower, out) if self.multiple_results else full_lower(out)
391
[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in process_primitive(self, primitive, tracers, params)
819
820 def process_primitive(self, primitive, tracers, params):
--> 821 return primitive.impl(*tracers, **params)
822
823 def process_call(self, primitive, f, tracers, params):
[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in apply_primitive(prim, *args, **params)
129 try:
130 in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 131 compiled_fun = xla_primitive_callable(
132 prim, in_avals, OrigShardings(in_shardings), **params)
133 except pxla.DeviceAssignmentMismatchError as e:
[/usr/local/lib/python3.10/dist-packages/jax/_src/util.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
261 return f(*args, **kwargs)
262 else:
--> 263 return cached(config._trace_context(), *args, **kwargs)
264
265 wrapper.cache_clear = cached.cache_clear
[/usr/local/lib/python3.10/dist-packages/jax/_src/util.py](https://localhost:8080/#) in cached(_, *args, **kwargs)
254 @functools.lru_cache(max_size)
255 def cached(_, *args, **kwargs):
--> 256 return f(*args, **kwargs)
257
258 @functools.wraps(f)
[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
220 return out,
221 donated_invars = (False,) * len(in_avals)
--> 222 compiled = _xla_callable_uncached(
223 lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
224 orig_in_shardings)
[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
250 fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
251 lowering_platform=None)
--> 252 return computation.compile().unsafe_call
253
254
[/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py](https://localhost:8080/#) in compile(self, compiler_options)
2204 **self.compile_args)
2205 else:
-> 2206 executable = UnloadedMeshExecutable.from_hlo(
2207 self._name,
2208 self._hlo,
[/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py](https://localhost:8080/#) in from_hlo(***failed resolving arguments***)
2542 break
2543
-> 2544 xla_executable, compile_options = _cached_compilation(
2545 hlo, name, mesh, spmd_lowering,
2546 tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
[/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py](https://localhost:8080/#) in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
2452 "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
2453 fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2454 xla_executable = dispatch.compile_or_get_cached(
2455 backend, computation, dev, compile_options, host_callbacks)
2456 return xla_executable, compile_options
[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
494
495 if not use_compilation_cache:
--> 496 return backend_compile(backend, computation, compile_options,
497 host_callbacks)
498
[/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
312 def wrapper(*args, **kwargs):
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
316 return wrapper
[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in backend_compile(backend, module, options, host_callbacks)
462 # TODO(sharadmv): remove this fallback when all backends allow `compile`
463 # to take in `host_callbacks`
--> 464 return backend.compile(built_c, compile_options=options)
465
466 _ir_dump_counter = itertools.count()
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Hi Sergey! I was trying to run
`af/examples/AF2Rank.ipynb`
but ran into the following "DNN library initialization failed" error. When I run the first cell under "## rank structures", I get the following error:
Do you know how to deal with this error? Thanks!