Closed: batrasakshi closed this issue 2 years ago.
I am trying out the demo notebook without a TPU backend, and I set `"cores_per_replica": 1, "per_replica_batch": 1` in `params`.
While executing

```python
network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1], shards_out=cores_per_replica)
```

I get this error:

```
AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)
```
Full error:

```
AssertionError                            Traceback (most recent call last)
File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:217, in read_ckpt_lowmem(pytree, dir, shards_in, shards_out, load_opt)
    216 try:
--> 217 unsharded = _unshard()
    218 except AssertionError:

File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:210, in read_ckpt_lowmem.<locals>._unshard()
    208 unsharded.append(x)
--> 210 assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
    211 device_index += 1

AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)

During handling of the above exception, another exception occurred:

AssertionError                            Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1],shards_out=cores_per_replica)

File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:222, in read_ckpt_lowmem(pytree, dir, shards_in, shards_out, load_opt)
    220 del pytree['opt_state']
    221 old_flattened, structure = jax.tree_flatten(pytree)
--> 222 unsharded = _unshard()
    224 loaded_pytree = jax.tree_unflatten(structure, unsharded)
    226 if not load_opt:

File /opt/conda/miniconda3/lib/python3.8/site-packages/mesh_transformer/checkpoint.py:210, in read_ckpt_lowmem.<locals>._unshard()
    207 x = reshard(x, old_flattened[device_index].shape)
    208 unsharded.append(x)
--> 210 assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
    211 device_index += 1
    213 print(f"read from disk/gcs in {time.time() - start:.06}s")

AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)
```
It worked after specifying `shards_in` explicitly.
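The two shapes in the assertion differ only along the second axis, and 50400 / 6300 = 8, which suggests the `step_383500/` checkpoint was saved in 8 shards while the original call passed `devices.shape[1]` (presumably 1 here, since `cores_per_replica` is 1) as `shards_in`. The sketch below infers that shard count from the mismatched shapes. `infer_shards_in` is a hypothetical helper written for illustration, not part of `mesh_transformer`, and it assumes exactly one axis was split evenly across the shards.

```python
def infer_shards_in(shard_shape, expected_shape):
    """Infer how many on-disk shards make up expected_shape.

    Hypothetical helper: assumes exactly one axis was split evenly
    across the shards and every other axis is unchanged.
    """
    if len(shard_shape) != len(expected_shape):
        raise ValueError("shapes must have the same rank")
    # Collect the axes where the loaded shard and the expected tensor disagree.
    differing = [i for i, (a, b) in enumerate(zip(shard_shape, expected_shape)) if a != b]
    if not differing:
        return 1  # shapes already match: one shard holds the full tensor
    if len(differing) > 1:
        raise ValueError(f"more than one axis differs: {differing}")
    axis = differing[0]
    full, part = expected_shape[axis], shard_shape[axis]
    if full % part != 0:
        raise ValueError(f"{full} is not a multiple of {part} on axis {axis}")
    return full // part


# Shapes copied from the AssertionError in this issue.
print(infer_shards_in((1, 6300, 4096), (1, 50400, 4096)))  # 8
```

If the inferred count is 8, the call from the issue would become `network.state = read_ckpt_lowmem(network.state, "step_383500/", 8, shards_out=cores_per_replica)`, so that `shards_in` describes how the checkpoint was written and `shards_out` describes the local layout. This follows the `read_ckpt_lowmem(pytree, dir, shards_in, shards_out, load_opt)` signature shown in the traceback; verify the shard count against your own checkpoint before relying on it.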