sokrypton / ColabDesign

Making Protein Design accessible to all via Google Colab!
529 stars 118 forks source link

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. #164

Closed johnnytam100 closed 7 months ago

johnnytam100 commented 7 months ago

Hi Sergey! I was trying to run af/examples/AF2Rank.ipynb but hit the following "DNN library initialization failed" error when I ran the first cell under "## rank structures":

NAME = "1mjc"
CHAIN = "A" # this can be multiple chains
NATIVE_PATH = f"{NAME}.pdb"
DECOY_DIR = f"{NAME}"

if save_output_pdbs:
  os.makedirs(f"{NAME}_output",ok_exists=True)

# get data
%shell wget -qnc https://files.ipd.uw.edu/pub/decoyset/natives/{NAME}.pdb
%shell wget -qnc https://files.ipd.uw.edu/pub/decoyset/decoys/{NAME}.zip
%shell unzip -qqo {NAME}.zip

# setup model
clear_mem()
af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])

I got the following error:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
[<ipython-input-4-5740c8be5964>](https://localhost:8080/#) in <cell line: 17>()
     15 # setup model
     16 clear_mem()
---> 17 af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])

21 frames
[<ipython-input-2-b0744074ff0e>](https://localhost:8080/#) in __init__(self, pdb, chain, model_name, model_names)
     75                  "model_name":model_name,
     76                  "model_names":model_names}
---> 77     self.reset()
     78 
     79   def reset(self):

[<ipython-input-2-b0744074ff0e>](https://localhost:8080/#) in reset(self)
     78 
     79   def reset(self):
---> 80     self.model = mk_af_model(protocol="fixbb",
     81                              use_templates=True,
     82                              use_multimer=self.args["use_multimer"],

[/content/colabdesign/af/model.py](https://localhost:8080/#) in __init__(self, protocol, use_multimer, use_templates, debug, data_dir, **kwargs)
    118     self._model_params, self._model_names = [],[]
    119     for model_name in model_names:
--> 120       params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir, fuse=True)
    121       if params is not None:
    122         if not self._args["use_multimer"] and not self._args["use_templates"]:

[/content/colabdesign/af/alphafold/model/data.py](https://localhost:8080/#) in get_model_haiku_params(model_name, data_dir, fuse)
     39     with open(path, 'rb') as f:
     40       params = np.load(io.BytesIO(f.read()), allow_pickle=False)
---> 41     return utils.flat_params_to_haiku(params, fuse=fuse)

[/content/colabdesign/af/alphafold/model/utils.py](https://localhost:8080/#) in flat_params_to_haiku(params, fuse)
    108             P[f"{k}/{c}"] = {}
    109             for d in ["bias","weights"]:
--> 110               P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1)
    111           P[f"{k}/center_norm"] = P.pop(f"{k}/center_layer_norm")
    112           P[f"{k}/left_norm_input"] = P.pop(f"{k}/layer_norm_input")

[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in concatenate(arrays, axis, dtype)
   1852   k = 16
   1853   while len(arrays_out) > 1:
-> 1854     arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1855                   for i in range(0, len(arrays_out), k)]
   1856   return arrays_out[0]

[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in <listcomp>(.0)
   1852   k = 16
   1853   while len(arrays_out) > 1:
-> 1854     arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1855                   for i in range(0, len(arrays_out), k)]
   1856   return arrays_out[0]

[/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py](https://localhost:8080/#) in concatenate(operands, dimension)
    615     if isinstance(op, Array):
    616       return type_cast(Array, op)
--> 617   return concatenate_p.bind(*operands, dimension=dimension)
    618 
    619 

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in bind(self, *args, **params)
    384     assert (not config.jax_enable_checks or
    385             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 386     return self.bind_with_trace(find_top_trace(args), args, params)
    387 
    388   def bind_with_trace(self, trace, args, params):

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in bind_with_trace(self, trace, args, params)
    387 
    388   def bind_with_trace(self, trace, args, params):
--> 389     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    390     return map(full_lower, out) if self.multiple_results else full_lower(out)
    391 

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in process_primitive(self, primitive, tracers, params)
    819 
    820   def process_primitive(self, primitive, tracers, params):
--> 821     return primitive.impl(*tracers, **params)
    822 
    823   def process_call(self, primitive, f, tracers, params):

[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in apply_primitive(prim, *args, **params)
    129   try:
    130     in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 131     compiled_fun = xla_primitive_callable(
    132         prim, in_avals, OrigShardings(in_shardings), **params)
    133   except pxla.DeviceAssignmentMismatchError as e:

[/usr/local/lib/python3.10/dist-packages/jax/_src/util.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    261         return f(*args, **kwargs)
    262       else:
--> 263         return cached(config._trace_context(), *args, **kwargs)
    264 
    265     wrapper.cache_clear = cached.cache_clear

[/usr/local/lib/python3.10/dist-packages/jax/_src/util.py](https://localhost:8080/#) in cached(_, *args, **kwargs)
    254     @functools.lru_cache(max_size)
    255     def cached(_, *args, **kwargs):
--> 256       return f(*args, **kwargs)
    257 
    258     @functools.wraps(f)

[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
    220       return out,
    221   donated_invars = (False,) * len(in_avals)
--> 222   compiled = _xla_callable_uncached(
    223       lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
    224       orig_in_shardings)

[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
    250       fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
    251       lowering_platform=None)
--> 252   return computation.compile().unsafe_call
    253 
    254 

[/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py](https://localhost:8080/#) in compile(self, compiler_options)
   2204             **self.compile_args)
   2205       else:
-> 2206         executable = UnloadedMeshExecutable.from_hlo(
   2207             self._name,
   2208             self._hlo,

[/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py](https://localhost:8080/#) in from_hlo(***failed resolving arguments***)
   2542           break
   2543 
-> 2544     xla_executable, compile_options = _cached_compilation(
   2545         hlo, name, mesh, spmd_lowering,
   2546         tuple_args, auto_spmd_lowering, allow_prop_to_outputs,

[/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py](https://localhost:8080/#) in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
   2452       "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
   2453       fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2454     xla_executable = dispatch.compile_or_get_cached(
   2455         backend, computation, dev, compile_options, host_callbacks)
   2456   return xla_executable, compile_options

[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
    494 
    495   if not use_compilation_cache:
--> 496     return backend_compile(backend, computation, compile_options,
    497                            host_callbacks)
    498 

[/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    312   def wrapper(*args, **kwargs):
    313     with TraceAnnotation(name, **decorator_kwargs):
--> 314       return func(*args, **kwargs)
    315     return wrapper
    316   return wrapper

[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in backend_compile(backend, module, options, host_callbacks)
    462   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    463   # to take in `host_callbacks`
--> 464   return backend.compile(built_c, compile_options=options)
    465 
    466 _ir_dump_counter = itertools.count()

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Do you know how to deal with this error? Thanks!

johnnytam100 commented 7 months ago

transferred to https://github.com/jproney/AF2Rank/issues/8