kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

to_hf_weights.py CPU assertion error #180

Closed: kirchner-jan closed this issue 2 years ago

kirchner-jan commented 2 years ago

When converting a slimmed model to the Hugging Face format on the CPU, I get the following assertion error:

  File "to_hf_weights.py", line 488, in <module>
    save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype)
  File "to_hf_weights.py", line 466, in save_sharded_to_hf_format
    save_pytree_as_hf(
  File "to_hf_weights.py", line 382, in save_pytree_as_hf
    x = unshard_leave(x, leave_name, old_shape, np_dtype=np_dtype)
  File "to_hf_weights.py", line 312, in unshard_leave
    assert isinstance(x, jnp.ndarray)
AssertionError

Simply commenting out the line

# assert isinstance(x, jnp.ndarray)

works fine.
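Instead of deleting the check outright, one option is to relax it so that any array-like leaf is accepted and converted to a host NumPy array. The sketch below is a hypothetical helper, `to_host_array`, and not the script's actual `unshard_leave` code. It assumes the failing leaf is an array type other than a JAX array, such as a plain NumPy array, which may happen when the checkpoint is loaded on the CPU. That cause has not been confirmed in this thread.

```python
import numpy as np


def to_host_array(x, np_dtype=None):
    """Hypothetical replacement for `assert isinstance(x, jnp.ndarray)`.

    Accepts any leaf that implements the NumPy array protocol (`__array__`),
    which covers both np.ndarray and JAX arrays, and returns a NumPy array.
    """
    if not hasattr(x, "__array__"):
        # Still fail loudly for leaves that are not arrays at all.
        raise TypeError(f"expected an array-like leaf, got {type(x).__name__}")
    # np.asarray copies device arrays to host memory; for an existing
    # np.ndarray with no dtype change it returns the same object.
    arr = np.asarray(x)
    if np_dtype is not None:
        arr = arr.astype(np_dtype)
    return arr


leaf = np.ones((2, 3), dtype=np.float32)
print(to_host_array(leaf, np.float16).dtype)  # float16
```

This keeps a sanity check on the leaf type while no longer requiring it to be a JAX array specifically.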