keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Don't double cast layer norm weights in mixed precision #815

Closed: mattdangerw closed this 1 year ago

mattdangerw commented 1 year ago

Previously, under mixed precision, we would autocast the full-precision layer norm variables to half precision, then manually cast them back to full precision.

This round trip might compile away (I'm not sure), but it is certainly very slow in eager mode. We should set `autocast=False` on these weights to avoid the needless type-conversion loop.
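To make the round trip concrete, here is a minimal pure-Python sketch, assuming the standard-library `struct` half (`e`) and single (`f`) precision formats as stand-ins for real tensor casts. `ToyVariable`, `read_as_float32`, and `layer_norm` are hypothetical names for illustration, not keras-core code; the `autocast` flag on `ToyVariable` only mirrors the `autocast=False` setting proposed above.

```python
import math
import struct


def to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_float32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


class ToyVariable:
    """Hypothetical float32 variable with an autocast switch."""

    def __init__(self, values, autocast=True):
        self.values = [to_float32(v) for v in values]
        self.autocast = autocast

    def read(self, compute_dtype):
        """Return (values, dtype) as a layer would see them inside call()."""
        if self.autocast and compute_dtype == "float16":
            # Mixed precision: the float32 variable is cast down on read.
            return [to_float16(v) for v in self.values], "float16"
        return list(self.values), "float32"


def read_as_float32(var, compute_dtype):
    """Read var for full-precision math; returns (values, casts used)."""
    values, dtype = var.read(compute_dtype)
    if dtype == "float32":
        return values, 0
    # Cast back up: together with the autocast on read, two casts per element.
    return [to_float32(v) for v in values], 2 * len(values)


def layer_norm(x, gamma, beta, compute_dtype="float16", epsilon=1e-3):
    """Layer norm over a 1-D list.

    Inputs and weights are rounded to float32 values; the arithmetic itself
    runs in Python floats. Returns (output, casts spent on gamma and beta).
    """
    g, g_casts = read_as_float32(gamma, compute_dtype)
    b, b_casts = read_as_float32(beta, compute_dtype)
    x32 = [to_float32(v) for v in x]
    mean = sum(x32) / len(x32)
    variance = sum((v - mean) ** 2 for v in x32) / len(x32)
    inv_std = 1.0 / math.sqrt(variance + epsilon)
    out = [gi * (v - mean) * inv_std + bi for v, gi, bi in zip(x32, g, b)]
    return out, g_casts + b_casts


x = [1.0, 2.0, 3.0, 4.0]
# Old behavior: weights autocast to float16, then cast back to float32.
old, old_casts = layer_norm(x, ToyVariable([1.0001] * 4), ToyVariable([0.0] * 4))
# Proposed behavior: autocast=False keeps the weights in float32 on read.
new, new_casts = layer_norm(
    x,
    ToyVariable([1.0001] * 4, autocast=False),
    ToyVariable([0.0] * 4, autocast=False),
)
print(old_casts, new_casts)  # 16 0
```

In this toy model the old path spends two casts per weight element on every call, and the float16 step also rounds 1.0001 down to 1.0 (float16 keeps only 10 explicit mantissa bits), while the `autocast=False` path reads the stored float32 values directly with no conversions.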