mfuntowicz closed this 5 months ago.
This is being fixed in https://github.com/huggingface/optimum-tpu/pull/21; you can see the change here: https://github.com/huggingface/optimum-tpu/pull/21/commits/cd992264a856bb5fbc9670df5e158179d914899b
It ended up being fixed in a different way, but the fix is still in https://github.com/huggingface/optimum-tpu/pull/21.
When loading Gemma model(s), the weights are upcast to `float32`, which doubles the memory requirements and slows down computation.
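As a rough illustration of where the "twice the memory" figure comes from: `float32` uses 4 bytes per value, while a 16-bit format uses 2. The sketch below assumes the weights would otherwise stay in `bfloat16` (the issue does not state the original dtype), counts weights only (no activations or optimizer state), and uses a hypothetical parameter count rather than any specific Gemma variant.

```python
# Bytes per element for common floating-point dtypes.
BYTES_PER_ELEMENT = {
    "float32": 4,
    "bfloat16": 2,
    "float16": 2,
}


def weight_memory_bytes(num_params: int, dtype: str) -> int:
    """Memory needed to store `num_params` weights in `dtype` (weights only)."""
    return num_params * BYTES_PER_ELEMENT[dtype]


# Hypothetical parameter count, for illustration only (not a real Gemma size).
NUM_PARAMS = 2_000_000_000

# Assumption: without the upcast, weights would be kept in bfloat16.
bf16 = weight_memory_bytes(NUM_PARAMS, "bfloat16")
fp32 = weight_memory_bytes(NUM_PARAMS, "float32")

print(f"bfloat16: {bf16 / 1e9:.1f} GB, float32: {fp32 / 1e9:.1f} GB, ratio: {fp32 / bf16:.1f}x")
```

Whatever the actual parameter count, the ratio stays at 2x, because it depends only on the per-element size of the two dtypes.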