Open fmmoret opened 5 months ago
Could anyone point me to scripts for converting LLaMA weights from PyTorch to JAX and from JAX back to PyTorch?
@lhao499, do you also have any that are compatible with num_key_value_heads != num_attention_heads?