As mentioned in the title, we're unable to shard any Llama model except 7B on a v4-32 TPU (and, I believe, v3-32) with 1-D model parallelism, because the number of heads isn't divisible by the number of devices (16 devices vs. 40 attention heads for Llama 13B and 8 KV heads for Llama 70B). I suggest implementing 2-D model parallelism (it's useful for Mistral as well). I've already implemented it here. In training it's ~1.2x slower than 1-D parallelism (measured on Mistral 7B), but it's more flexible and could probably be optimized further.
This implementation is essentially the "final attempt" 2-D sharding scheme from this paper.
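For illustration, here is a minimal JAX sketch of the idea (not the code from this PR): the attention projection is sharded over a 2-D device mesh, so each mesh axis only has to divide its own dimension, rather than the head count having to be divisible by all 16 devices at once. The 4x4 mesh layout and the axis names `mp1`/`mp2` are assumptions for the example, not the actual configuration used in the implementation.

```python
# Hedged sketch of 2-D model parallelism on a 16-device slice (e.g. v4-32).
# Assumes 16 visible devices; mesh shape and axis names are illustrative.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the 16 devices as a 4x4 mesh instead of a single 16-way axis.
devices = mesh_utils.create_device_mesh((4, 4))
mesh = Mesh(devices, axis_names=("mp1", "mp2"))

hidden = 5120      # Llama-13B hidden size
num_heads = 40     # not divisible by 16, but 40 % 4 == 0
head_dim = hidden // num_heads

# Q projection: shard the input dim over "mp1" and the head dim over "mp2".
# Each mesh axis only needs to divide its own dimension (5120 % 4 == 0, 40 % 4 == 0).
wq = jnp.zeros((hidden, num_heads, head_dim))
wq = jax.device_put(wq, NamedSharding(mesh, P("mp1", "mp2", None)))

# Activations are sharded over "mp1" on the hidden dim, so the projection
# contracts over "mp1" (a collective reduce) and its output stays head-sharded
# over "mp2".
x = jnp.zeros((8, 2048, hidden))
x = jax.device_put(x, NamedSharding(mesh, P(None, None, "mp1")))
```

With a layout like this, the head count only has to be divisible by one mesh axis (4 here) rather than by all 16 devices, which is what makes the 13B and 70B configurations shardable on these slices.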