google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Gemma sharding and test #70

Closed FanhaiLu1 closed 3 months ago

FanhaiLu1 commented 4 months ago

This PR does two things:
1. Make Gemma replica sharding consistent with llama.
2. Allow setting bfloat16 or float32 in make_env_tiny.
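A minimal sketch of the dtype toggle described in point 2, assuming a hypothetical environment config object. The real make_env_tiny lives in this repo's test helpers; the names and fields below are illustrative, not the actual API:

```python
from dataclasses import dataclass

@dataclass
class EnvData:
    # Hypothetical stand-in for the engine environment config;
    # bfloat16_enable selects the working precision.
    bfloat16_enable: bool = True

    @property
    def dtype(self) -> str:
        # bfloat16 when enabled, otherwise full float32.
        return "bfloat16" if self.bfloat16_enable else "float32"

def make_env_tiny(bf16_enable: bool = True) -> EnvData:
    # Illustrative helper that builds a tiny environment
    # with the requested precision, as in point 2 above.
    return EnvData(bfloat16_enable=bf16_enable)

print(make_env_tiny(True).dtype)   # bfloat16
print(make_env_tiny(False).dtype)  # float32
```

The flag-based toggle keeps tests cheap: tiny environments run in bfloat16 by default but can be forced to float32 when checking numerical consistency.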