huggingface / optimum-tpu

Google TPU optimizations for transformers models
Apache License 2.0

Weights upcasted to `float32` at load time #23

Closed mfuntowicz closed 5 months ago

mfuntowicz commented 5 months ago

When loading Gemma model(s), the weights are upcast to float32, doubling the memory requirements and slowing down computation.
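A quick back-of-the-envelope sketch of why this matters (illustrative only, not optimum-tpu code; the parameter count below is a round hypothetical, not Gemma's exact size): bfloat16 stores 2 bytes per parameter, float32 stores 4, so an unintended upcast doubles the weight footprint.

```python
# Bytes per parameter for the two dtypes involved in the upcast.
BYTES_PER_PARAM = {"bfloat16": 2, "float32": 4}

def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Approximate memory needed to hold the weights in the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3

params = 2_000_000_000  # hypothetical ~2B-parameter model
print(f"bfloat16: {weight_memory_gib(params, 'bfloat16'):.2f} GiB")
print(f"float32:  {weight_memory_gib(params, 'float32'):.2f} GiB")
```

The usual way to avoid the upcast with `transformers` is to request the checkpoint dtype explicitly, e.g. `AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)`.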

tengomucho commented 5 months ago

This is being fixed in https://github.com/huggingface/optimum-tpu/pull/21, you can see the change here: https://github.com/huggingface/optimum-tpu/pull/21/commits/cd992264a856bb5fbc9670df5e158179d914899b

tengomucho commented 5 months ago

Fixed in a different way; in any case, the fix landed in https://github.com/huggingface/optimum-tpu/pull/21.