Closed levscaut closed 2 years ago
To anyone who comes across this bug: I managed to solve it by limiting the number of CUDA devices visible to Python. The pre-trained weights only apply to a single GPU, so simply adding this at the top of your code will fix it:

```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
```
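A minimal sketch of the workaround in context (the GPU index `"0"` is an assumption; pick whichever GPU you want the model on). The important subtlety is that the environment variable must be set *before* `jax` is imported, since JAX discovers devices at import/initialization time:

```python
import os

# Restrict CUDA device visibility BEFORE importing jax, so that device
# discovery only ever sees a single GPU. With one visible device, pjit
# no longer tries to split the model dimension across 3 GPUs.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# import jax  # must come *after* the environment variable is set
```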
In order to run the code locally I cloned the Colab notebook and finally set up the environment. Yet when running the code from the notebook, this error occurred, with the console output below:

```
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/flax/optim/base.py:52: DeprecationWarning: Use `optax` instead of `flax.optim`. Refer to the update guide https://flax.readthedocs.io/en/latest/howtos/optax_update_guide.html for detailed instructions.
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py:183: UserWarning: pjit is an experimental feature and probably has bugs!
/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/_src/lib/xla_bridge.py:430: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
Traceback (most recent call last):
  File "/mnt/fast/lwd/aisheet/test.py", line 252, in <module>
    inference_model = InferenceModel(checkpoint_path, MODEL)
  File "/mnt/fast/lwd/aisheet/test.py", line 88, in __init__
    self.restore_from_checkpoint(checkpoint_path)
  File "/mnt/fast/lwd/aisheet/test.py", line 134, in restore_from_checkpoint
    [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 522, in from_checkpoint_or_scratch
    return (self.from_checkpoint(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng)
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 508, in from_checkpoint
    self.from_checkpoints(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng))
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 466, in from_checkpoints
    yield _restore_path(path, restore_cfg)
  File "/mnt/fast/lwd/aisheet/t5x/utils.py", line 458, in _restore_path
    fallback_state=fallback_state)
  File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 880, in restore
    return self._restore_train_state(state_dict)
  File "/mnt/fast/lwd/aisheet/t5x/checkpoints.py", line 891, in _restore_train_state
    train_state, train_state_axes)
  File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 639, in move_params_to_devices
    train_state, = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
  File "/mnt/fast/lwd/aisheet/t5x/partitioning.py", line 729, in __call__
    return self._pjitted_fn(*args)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 266, in wrapped
    args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 250, in infer_params
    tuple(isinstance(a, GDA) for a in args_flat))
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 385, in _pjit_jaxpr
    allow_uneven_sharding=False)
  File "/home/lwd/.conda/envs/lwd_mt3/lib/python3.7/site-packages/jax/experimental/pjit.py", line 581, in _check_shapes_against_resources
    raise ValueError(f"One of {what} was given the resource assignment "
ValueError: One of pjit arguments was given the resource assignment of PartitionSpec(None, 'model'), which implies that the size of its dimension 1 should be divisible by 3, but it is equal to 1024
```

This occurs when executing the line around 252: `inference_model = InferenceModel(checkpoint_path, MODEL)`. I have no idea why this happened; I hope you could help me work it out. Thanks!
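For context on the error itself (a hypothetical illustration, not MT3 or t5x code): `pjit` shards the dimension mapped to the `'model'` mesh axis evenly across the devices on that axis. The machine in the traceback exposed 3 GPUs to JAX, and 1024 is not divisible by 3, hence the `ValueError`; with a single visible GPU the split is trivial and succeeds:

```python
def divides_evenly(dim_size: int, num_devices: int) -> bool:
    """pjit requires each sharded dimension to split evenly across devices."""
    return dim_size % num_devices == 0

# The model dimension of size 1024 over 3 visible GPUs -> the ValueError above.
print(divides_evenly(1024, 3))  # False
# After CUDA_VISIBLE_DEVICES="0", only one device is visible -> works.
print(divides_evenly(1024, 1))  # True
```

This also suggests why 2 or 4 visible GPUs would not have triggered the error: both divide 1024 evenly.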