magenta / mt3

MT3: Multi-Task Multitrack Music Transcription
Apache License 2.0

Model Loading Failed in Colab #159

Status: Open. Catoverflow opened this issue 2 months ago.

Catoverflow commented 2 months ago

I tried to run MT3 in Colab, but model loading failed. I'm not familiar with the DNN libraries, so I'm only posting the steps to reproduce here.

Steps to Reproduce

  1. Select T4 GPU as the runtime type.
  2. Run the Setup Environment cell. pip reports:
    ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
    tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.17.0 which is incompatible.
  3. Run the Import and Definitions cell.
  4. Run the Load Model cell with either mt3 or ismir2021. The notebook fails with:

XlaRuntimeError Traceback (most recent call last)

in <cell line: 13>()
     11
     12 log_event('loadModelStart', {'event_category': MODEL})
---> 13 inference_model = InferenceModel(checkpoint_path, MODEL)
     14 log_event('loadModelComplete', {'event_category': MODEL})


in __init__(self, checkpoint_path, model_type)
     85
     86 # Restore from checkpoint.
---> 87 self.restore_from_checkpoint(checkpoint_path)
     88
     89 @property

in restore_from_checkpoint(self, checkpoint_path)
    120 def restore_from_checkpoint(self, checkpoint_path):
    121 """Restore training state from checkpoint, resets self._predict_fn()."""
--> 122 train_state_initializer = t5x.utils.TrainStateInitializer(
    123 optimizer_def=self.model.optimizer_def,
    124 init_fn=self.model.get_initial_variables,

/usr/local/lib/python3.10/dist-packages/t5x/ in __init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
   1057 self._partitioner = partitioner
   1058 self.global_train_state_shape = jax.eval_shape(
-> 1059 initialize_train_state, rng=jax.random.PRNGKey(0)
   1060 )
   1061 self.train_state_axes = partitioner.get_mesh_axes(

/usr/local/lib/python3.10/dist-packages/jax/_src/ in PRNGKey(seed, impl)
    231 and fold_in.
    232 """
--> 233 return _return_prng_keys(True, _key('PRNGKey', seed, impl))
    234
    235

/usr/local/lib/python3.10/dist-packages/jax/_src/ in _key(ctor_name, seed, impl_spec)
    193 f"{ctor_name} accepts a scalar seed, but was given an array of "
    194 f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 195 return prng.random_seed(seed, impl=impl)
    196
    197 def key(seed: int | ArrayLike, *,

/usr/local/lib/python3.10/dist-packages/jax/_src/ in random_seed(seeds, impl)
    531 # use-case of instantiating with Python hashes in X32 mode.
    532 if isinstance(seeds, int):
--> 533 seeds_arr = jnp.asarray(np.int64(seeds))
    534 else:
    535 seeds_arr = jnp.asarray(seeds)

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ in asarray(a, dtype, order, copy)
   3287 if dtype is not None:
   3288 dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
-> 3289 return array(a, dtype=dtype, copy=bool(copy), order=order)
   3290
   3291

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ in array(object, dtype, copy, order, ndmin)
   3212 raise TypeError(f"Unexpected input type for array: {type(object)}")
   3213
-> 3214 out_array: Array = lax_internal._convert_element_type(
   3215 out, dtype, weak_type=weak_type)
   3216 if ndmin > ndim(out_array):

/usr/local/lib/python3.10/dist-packages/jax/_src/lax/ in _convert_element_type(operand, new_dtype, weak_type)
    557 return type_cast(Array, operand)
    558 else:
--> 559 return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    560 weak_type=bool(weak_type))
    561

/usr/local/lib/python3.10/dist-packages/jax/_src/ in bind(self, *args, **params)
    414 assert (not config.enable_checks.value or
    415 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 416 return self.bind_with_trace(find_top_trace(args), args, params)
    417
    418 def bind_with_trace(self, trace, args, params):

/usr/local/lib/python3.10/dist-packages/jax/_src/ in bind_with_trace(self, trace, args, params)
    418 def bind_with_trace(self, trace, args, params):
    419 with pop_level(trace.level):
--> 420 out = trace.process_primitive(self, map(trace.full_raise, args), params)
    421 return map(full_lower, out) if self.multiple_results else full_lower(out)
    422

/usr/local/lib/python3.10/dist-packages/jax/_src/ in process_primitive(self, primitive, tracers, params)
    919 return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    920 else:
--> 921 return primitive.impl(*tracers, **params)
    922
    923 def process_call(self, primitive, f, tracers, params):

/usr/local/lib/python3.10/dist-packages/jax/_src/ in apply_primitive(prim, *args, **params)
     85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86 try:
---> 87 outs = fun(*args)
     88 finally:
     89 lib.jax_jit.swap_thread_local_state_disable_jit(prev)

[... skipping hidden 15 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/ in backend_compile(backend, module, options, host_callbacks)
    236 # TODO(sharadmv): remove this fallback when all backends allow compile
    237 # to take in host_callbacks
--> 238 return backend.compile(built_c, compile_options=options)
    239
    240 def compile_or_get_cached(

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
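The traceback shows the error being raised while JAX compiles its first small operation, the `jax.random.PRNGKey(0)` call made inside `t5x.utils.TrainStateInitializer`. A minimal sketch for checking a runtime without loading the MT3 checkpoint, assuming only that `jax` may or may not be installed: `jax_backend_smoke_test` is a hypothetical helper, not part of MT3 or the Colab notebook. It reports the default backend and visible devices, then repeats that same `PRNGKey(0)` call; on an affected runtime the `error` field would presumably carry the DNN initialization failure.

```python
def jax_backend_smoke_test():
    """Hypothetical helper: describe the JAX backend and retry the failing call.

    Returns None if JAX is not installed, otherwise a dict with the backend,
    the visible devices, and whether jax.random.PRNGKey(0) succeeded.
    """
    try:
        import jax
    except ImportError:
        return None

    report = {
        "jax_version": jax.__version__,
        "backend": None,
        "devices": [],
        "ok": False,
        "error": None,
    }
    try:
        report["backend"] = jax.default_backend()
        report["devices"] = [str(d) for d in jax.devices()]
        # Same call as in the traceback above; it forces XLA to compile a
        # small computation on the default backend.
        key = jax.random.PRNGKey(0)
        key.block_until_ready()
        report["ok"] = True
    except Exception as exc:  # the XlaRuntimeError import path differs across jaxlib versions
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report


if __name__ == "__main__":
    print(jax_backend_smoke_test())
```

Running it in its own cell right after the Setup Environment cell would separate "the JAX/CUDA installation is broken" from "something in MT3 is broken".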

ntamotsu commented 2 months ago

same issue

goel-raghav commented 1 month ago

same problem here

goel-raghav commented 1 month ago

It seems that either switching the runtime to CPU or changing

!python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f

to

!python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .

fixes the problem. I'm not sure which of the two did it, because I seem to have run out of Colab GPU hours.
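Switching the runtime type is not the only way to keep JAX off the GPU: JAX honours the `JAX_PLATFORMS` environment variable, and setting it to `cpu` restricts JAX to the CPU backend. The sketch below assumes it runs in a fresh (or freshly restarted) runtime before any cell imports `jax`; it is an assumed equivalent of "changing to CPU", not something reported in this thread.

```python
import os

# Must be set before jax is imported; setting it after JAX has already
# started may have no effect, so restart the runtime first if needed.
os.environ["JAX_PLATFORMS"] = "cpu"

try:
    import jax
except ImportError:
    jax = None

if jax is not None:
    # With JAX_PLATFORMS=cpu the default backend should be "cpu",
    # even on a GPU runtime.
    print("JAX backend:", jax.default_backend())
else:
    print("JAX is not installed")
```

Like the pip change, this avoids the GPU entirely, so inference runs at CPU speed.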

Catoverflow commented 1 month ago

@goel-raghav Yes. After changing the pip command, JAX prints

WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

and the model loads successfully, running on the CPU.
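The warning means no CUDA-enabled jaxlib is installed, so JAX runs everything on the CPU even though the runtime has a GPU. A minimal sketch for confirming that state from a notebook cell: `gpu_fallback_status` is a hypothetical helper, and it treats the presence of the `nvidia-smi` binary as a rough stand-in for "an NVIDIA GPU is attached", which is an assumption rather than a reliable hardware check.

```python
import shutil


def gpu_fallback_status():
    """Hypothetical helper: is a GPU apparently present while JAX uses the CPU?

    Assumption: the nvidia-smi binary on PATH is used as a rough proxy for an
    attached NVIDIA GPU.
    """
    gpu_visible = shutil.which("nvidia-smi") is not None
    try:
        import jax
        backend = jax.default_backend()
    except (ImportError, RuntimeError):
        # JAX missing, or its backend could not be initialized.
        backend = None
    return {
        "nvidia_smi_found": gpu_visible,
        "jax_backend": backend,
        # True matches the "Falling back to cpu" situation in the warning.
        "running_on_cpu_despite_gpu": gpu_visible and backend == "cpu",
    }


if __name__ == "__main__":
    print(gpu_fallback_status())
```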

laqieer commented 1 month ago

works for GPU, which is much faster than CPU.