magenta / mt3

MT3: Multi-Task Multitrack Music Transcription
Apache License 2.0

Error when trying to load the model #27

Closed levscaut closed 2 years ago

levscaut commented 2 years ago

In order to run the code locally, I cloned the Colab notebook and set up the environment. But when running the code from the notebook, this error occurred, with the console output below:

```
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/flax/optim/base.py:52: DeprecationWarning: Use optax instead of flax.optim. Refer to the update guide https://flax.readthedocs.io/en/latest/howtos/optax_update_guide.html for detailed instructions.
  'for detailed instructions.', DeprecationWarning)
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py:183: UserWarning: pjit is an experimental feature and probably has bugs!
  warn("pjit is an experimental feature and probably has bugs!")
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/_src/lib/xla_bridge.py:430: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  "jax.host_count has been renamed to jax.process_count. This alias "
Traceback (most recent call last):
  File "/mnt/fast/lwd/aisheet/test.py", line 252, in <module>
    inference_model = InferenceModel(checkpoint_path, MODEL)
  File "/mnt/fast/lwd/aisheet/test.py", line 88, in __init__
    self.restore_from_checkpoint(checkpoint_path)
  File "/mnt/fast/lwd/aisheet/test.py", line 134, in restore_from_checkpoint
    [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 522, in from_checkpoint_or_scratch
    return (self.from_checkpoint(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng)
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 508, in from_checkpoint
    self.from_checkpoints(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng))
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 466, in from_checkpoints
    yield _restore_path(path, restore_cfg)
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 458, in _restore_path
    fallback_state=fallback_state)
  File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 880, in restore
    return self._restore_train_state(state_dict)
  File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 891, in _restore_train_state
    train_state, train_state_axes)
  File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 639, in move_params_to_devices
    train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
  File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 729, in __call__
    return self._pjitted_fn(*args)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 266, in wrapped
    args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 250, in infer_params
    tuple(isinstance(a, GDA) for a in args_flat))
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 385, in _pjit_jaxpr
    allow_uneven_sharding=False)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 581, in _check_shapes_against_resources
    raise ValueError(f"One of {what} was given the resource assignment "
ValueError: One of pjit arguments was given the resource assignment of PartitionSpec(None, 'model'), which implies that the size of its dimension 1 should be divisible by 3, but it is equal to 1024
```

This occurs when executing line 252: `inference_model = InferenceModel(checkpoint_path, MODEL)`. I have no idea why this is happening; I'm hoping you could help me work it out. Thanks!

levscaut commented 2 years ago

To anyone who comes across this bug: I managed to solve it by limiting the number of CUDA devices visible to Python. The pre-trained weights only work on one GPU, so simply adding this at the top of your code will fix it:

```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
```
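For context on why this works (my own reading of the error message, not something stated by the MT3 authors): `pjit` shards any dimension assigned to the `'model'` mesh axis evenly across the devices on that axis. With three GPUs visible, a 1024-wide parameter dimension cannot be split into three equal chunks, so checkpoint restoration fails; with one visible device the constraint is trivially satisfied. A minimal sketch of the divisibility check, using the numbers from the ValueError:

```python
# Sketch (illustrative only): the sharding constraint behind the ValueError.
# A dimension placed on the 'model' mesh axis must divide evenly across the
# devices assigned to that axis. embed_dim = 1024 is the size reported above.
embed_dim = 1024

for num_devices in (1, 2, 3, 4):
    remainder = embed_dim % num_devices
    status = "OK" if remainder == 0 else "fails (uneven sharding)"
    print(f"{num_devices} visible device(s): {embed_dim} % {num_devices} = {remainder} -> {status}")
```

This prints that 1, 2, and 4 devices would shard evenly, while 3 devices leave a remainder, which matches the "divisible by 3, but it is equal to 1024" complaint. Setting `CUDA_VISIBLE_DEVICES=0` just forces the one-device case.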