ayaka14732 / llama-2-jax

JAX implementation of the Llama 2 model
https://arxiv.org/abs/2307.09288
Creative Commons Zero v1.0 Universal

Unable to shard llama 13B (and 70B) on v4-32 TPU #26

Open defdet opened 8 months ago

defdet commented 8 months ago

As mentioned in the title, we're unable to shard any llama model except for 7B on v4-32 TPU (and I think v3-32) with 1-d model parralelism because number of heads isn't divisible by number of devices (16 devices: 40 attention heads, 8 kv heads for llama 13B, 70B). I suggest implementing 2-d model parralelism (it's usefull for mistral as well). I've already implemented it here. In training, it's ~1.2 times slower than 1-d parallel (measured on mistral 7B), but is more flexible and probably could be optimized further. This implementation is basically the "final attempt" of 2-d sharding from this paper