google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Enable Gemma 2B #75

Closed: qihqi closed this 3 months ago

qihqi commented 3 months ago

For Gemma 2B we need to change the shardings. The dimension we usually shard, `num_kv_heads`, happens to be 1 for Gemma 2B, and a dimension of size 1 cannot be split across multiple devices. So we shard along a different dimension instead.
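A minimal sketch of the fallback idea, assuming a KV-cache layout of `(batch, num_kv_heads, seq_len, head_dim)` and a hypothetical helper `choose_shard_axis` (neither is taken from this PR). It tries the usual `num_kv_heads` axis first and, if that axis cannot be split evenly across the devices, falls back to another axis. Here `head_dim` is used as the fallback purely for illustration; the PR does not say which dimension was actually chosen. The shapes below are illustrative too; only `num_kv_heads == 1` for Gemma 2B comes from the PR.

```python
from typing import Sequence


def choose_shard_axis(
    shape: Sequence[int],
    num_devices: int,
    preferred_axis: int,
    fallback_axes: Sequence[int],
) -> int:
    """Return the first axis that splits evenly across num_devices.

    Hypothetical helper: tries preferred_axis first (e.g. num_kv_heads),
    then each fallback axis in order.
    """
    for axis in (preferred_axis, *fallback_axes):
        size = shape[axis]
        # An axis can be split into num_devices equal shards only if its
        # size is a positive multiple of num_devices.
        if size >= num_devices and size % num_devices == 0:
            return axis
    raise ValueError(
        f"no axis of shape {tuple(shape)} divides across {num_devices} devices"
    )


# Assumed KV-cache layout: (batch, num_kv_heads, seq_len, head_dim).
KV_HEAD_AXIS = 1
HEAD_DIM_AXIS = 3  # hypothetical fallback axis, chosen for illustration only

# Illustrative shapes: a model with 32 KV heads vs. one with a single KV head.
multi_kv_head_cache = (1, 32, 1024, 128)
single_kv_head_cache = (1, 1, 1024, 256)

print(choose_shard_axis(multi_kv_head_cache, 8, KV_HEAD_AXIS, [HEAD_DIM_AXIS]))   # → 1
print(choose_shard_axis(single_kv_head_cache, 8, KV_HEAD_AXIS, [HEAD_DIM_AXIS]))  # → 3
```

With 8 devices, the 32-head cache keeps sharding on `num_kv_heads`, while the single-head cache falls back to the last axis, because one head cannot be divided eight ways.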