google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

gemma-2b-it-pytorch on tpu v5p #64

Closed shungcp closed 4 weeks ago

shungcp commented 1 month ago

Hello. Following the steps in README.md to run gemma-2b-it-pytorch on a TPU v5p, I get the error below. What is wrong here?

 python scripts/run_xla.py --ckpt /tmp/1/gemma-2b-it.ckpt --variant 2b

TypeError: where() received an invalid combination of arguments - got (Tensor, Tensor, tuple), but expected one of:

michaelmoynihan commented 4 weeks ago

Thanks! I created a PR to address this: https://github.com/google/gemma_pytorch/pull/65
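
For anyone hitting the same message: `torch.where` raises this TypeError when its third argument is a Python tuple rather than a tensor or a scalar. The sketch below only illustrates that failure mode. It assumes, without confirmation from this thread or from the PR, that some call returned a tuple where a single tensor was expected. The helper `select_next_tokens` and all sample values are hypothetical and are not taken from the repository.

```python
import torch


def select_next_tokens(prompt_mask, prompt_tokens, sampler_output):
    """Keep prompt tokens where the mask is True, sampled tokens elsewhere.

    Hypothetical helper: `sampler_output` may be a tensor or a tuple whose
    first element is the token tensor. The tuple case is an assumption about
    what could produce the error reported above.
    """
    if isinstance(sampler_output, tuple):
        # Unpack so that torch.where receives a tensor, not a tuple.
        sampler_output = sampler_output[0]
    return torch.where(prompt_mask, prompt_tokens, sampler_output)


# Made-up sample data for illustration only.
mask = torch.tensor([True, False, True])
prompt = torch.tensor([10, 11, 12])
sampled = torch.tensor([20, 21, 22])
logits = torch.zeros(3, 5)

# Passing a tuple as the third argument reproduces the TypeError.
try:
    torch.where(mask, prompt, (sampled, logits))
    raised = False
except TypeError:
    raised = True

# Unpacking the tuple first makes the same call succeed.
result = select_next_tokens(mask, prompt, (sampled, logits))
```

If the traceback points at a `torch.where` call, printing `type(...)` of each argument just before that call is a quick way to check whether one of them is a tuple.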