tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

Fix `dtype` and `assign*` in `AutocastVariable`. #1136

Open copybara-service[bot] opened 3 days ago

The dtype property would return the true dtype of the variable instead of the dtype of the value that you get explicitly via .value() or implicitly by performing any operation.

This would cause seemingly correct things like this to fail with a dtype mismatch:

y = variable * tf.cast(x, variable.dtype)

This forced users to write workarounds like:

v = variable.value()
y = variable * tf.cast(x, v.dtype)

Additionally, assign, assign_add, and assign_sub expected the value to be of the true dtype, not the cast dtype.

This would cause seemingly correct things like this to fail with a dtype mismatch:

variable.assign(variable * factor)

(This is a common use case for non-trainable variables.)

This forced users to write workarounds like:

variable.assign(tf.cast(variable * factor, variable.dtype))

This change fixes these issues to make autocasting fully transparent.

Note that this is consistent with how autocasting works in Keras 3.
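
To illustrate the intended behavior, here is a minimal pure-Python sketch. It is a toy model, not the actual AutocastVariable implementation: `Tensor`, `cast`, `ToyAutocastVariable`, and its `true_dtype` property are hypothetical names invented for this example, dtypes are plain strings, and the sketch assumes the variable is always autocast (the real class only casts under certain conditions).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tensor:
    """Toy tensor (hypothetical): a float payload tagged with a dtype name."""
    data: float
    dtype: str

    def __mul__(self, other):
        # Python scalars adopt this tensor's dtype.
        if not isinstance(other, Tensor):
            other = Tensor(float(other), self.dtype)
        # Mimic strict dtype checking: no implicit promotion between tensors.
        if other.dtype != self.dtype:
            raise TypeError(f"dtype mismatch: {self.dtype} vs {other.dtype}")
        return Tensor(self.data * other.data, self.dtype)


def cast(t, dtype):
    """Toy stand-in for a cast: relabel the payload with a new dtype."""
    return Tensor(t.data, dtype)


class ToyAutocastVariable:
    """Toy model of the post-fix behavior described above (an assumption)."""

    def __init__(self, initial, true_dtype="float32", cast_dtype="float16"):
        self._stored = Tensor(float(initial), true_dtype)
        self._cast_dtype = cast_dtype

    @property
    def dtype(self):
        # Report the dtype that callers observe when they read the variable.
        return self._cast_dtype

    @property
    def true_dtype(self):
        # The dtype the value is actually stored in.
        return self._stored.dtype

    def value(self):
        return cast(self._stored, self._cast_dtype)

    def __mul__(self, other):
        # Operations implicitly use the cast value.
        return self.value() * other

    def assign(self, new_value):
        # Accept a value in the cast dtype and store it in the true dtype.
        if not isinstance(new_value, Tensor):
            new_value = Tensor(float(new_value), self._cast_dtype)
        self._stored = cast(new_value, self._stored.dtype)
        return self


variable = ToyAutocastVariable(2.0)
x = Tensor(3.0, "float32")

# Casting to variable.dtype now matches the dtype used in the multiplication.
y = variable * cast(x, variable.dtype)

# Assigning a value computed from the variable works without a manual cast.
variable.assign(variable * 0.5)
```

In this sketch, both patterns that previously needed workarounds run without a dtype mismatch, while the stored value keeps its true dtype.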