google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

MPS (Apple Silicon) Support #11

Open: dsanmart opened this issue 6 months ago

dsanmart commented 6 months ago

Will there be MPS support for the Gemma models? It would make them accessible to a larger community.

lamroger commented 6 months ago

Took a look; a few things:

MPS support on Linux still looks to be in progress (https://github.com/pytorch/pytorch/issues/81224), so running in a container isn't an option yet.

MPS currently has limited support for complex tensors. Gemma implements RoPE (rotary positional embeddings) with complex tensors, so running it locally on MPS errors out.

https://github.com/pytorch/pytorch/pull/116764/files#diff-fe061f10677283971d77576718d3a04a00b2225d72c043fd59222a882b92c64bR654

https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L426
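One common workaround, assuming you are willing to patch the model code, is to express the same rotation with real-valued cos/sin so no complex tensors are ever created. The sketch below is illustrative only: `rope_complex`, `rope_real` and `_rope_angles` are hypothetical names, not gemma_pytorch's actual functions, and they assume a `(seq_len, head_dim)` input with the "first half / second half" pairing convention. It checks on CPU that both formulations agree.

```python
import torch


def _rope_angles(seq_len: int, dim: int, theta: float) -> torch.Tensor:
    """Rotation angle for every (position, pair); shape (seq_len, dim // 2)."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    return torch.outer(positions, inv_freq)


def rope_complex(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """RoPE via complex multiplication; requires complex dtype support on x.device."""
    seq_len, dim = x.shape
    half = dim // 2
    angles = _rope_angles(seq_len, dim, theta).to(x.device)
    rot = torch.polar(torch.ones_like(angles), angles)  # complex64 e^{i*angle}
    # Pair element j with element j + half: first half is the real part, second half the imaginary part.
    x_c = torch.complex(x[:, :half].float(), x[:, half:].float())
    out = x_c * rot
    return torch.cat([out.real, out.imag], dim=-1).type_as(x)


def rope_real(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Same rotation using only real cos/sin, so no complex tensors are created."""
    seq_len, dim = x.shape
    half = dim // 2
    angles = _rope_angles(seq_len, dim, theta).to(x.device)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half].float(), x[:, half:].float()
    # (x1 + i*x2) * (cos + i*sin) = (x1*cos - x2*sin) + i*(x1*sin + x2*cos)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).type_as(x)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 16)  # (seq_len, head_dim), computed on CPU
    print(torch.allclose(rope_complex(x), rope_real(x), atol=1e-5))
```

On Apple Silicon you would call `rope_real` on tensors moved with `x.to("mps")`; whether swapping this in is enough to run Gemma end to end on MPS is not verified here.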

Command used to run it locally: python scripts/run.py --ckpt gemma-2b-it.ckpt --variant 2b --device mps
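Before passing `--device mps`, it can help to confirm the backend is available and to probe whether it accepts the complex ops RoPE relies on. A minimal sketch; `pick_device` is a hypothetical helper, not part of gemma_pytorch, and the probe is only a smoke test, not a guarantee the full model will run.

```python
import torch


def pick_device(preferred: str = "mps") -> torch.device:
    """Return the preferred device if PyTorch reports it as usable, else fall back to CPU."""
    if preferred == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    if preferred == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


if __name__ == "__main__":
    device = pick_device("mps")
    print(f"Using device: {device}")
    # Smoke test: build and multiply a small complex tensor on the chosen device.
    try:
        z = torch.polar(torch.ones(4, device=device), torch.zeros(4, device=device))
        print("complex ops OK:", (z * z).real.tolist())
    except (RuntimeError, TypeError) as err:
        print(f"complex ops failed on {device}: {err}")
```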

bghira commented 5 months ago

PyTorch 2.3 has bf16 and complex tensor support, and Dockerised containers now work, @lamroger.
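Since the fix depends on the installed PyTorch release, a script could check the version before trying MPS. A minimal stdlib-only sketch with a hypothetical helper `torch_version_at_least`; the 2.3 threshold comes from the comment above, and any additional OS requirements for bf16 on MPS are not covered here.

```python
import re


def torch_version_at_least(version: str, minimum: tuple[int, int]) -> bool:
    """True if the major.minor of a version string like '2.3.0+cpu' is >= minimum."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    # Compare as integer tuples: a plain string compare would rank "2.10" below "2.3".
    return (int(match.group(1)), int(match.group(2))) >= minimum


if __name__ == "__main__":
    import torch  # assumes PyTorch is installed

    if torch_version_at_least(torch.__version__, (2, 3)):
        print(f"PyTorch {torch.__version__}: 2.3 or newer")
    else:
        print(f"PyTorch {torch.__version__}: older than 2.3, consider upgrading before trying --device mps")
```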