keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Fix unexpected dtype policy changes when quantization fails #19690

Closed: james77777778 closed this 3 weeks ago

james77777778 commented 3 weeks ago

This PR fixes unexpected dtype policy changes when quantization fails.

For example, on the current codebase the following snippet prints False, because the failed `quantize()` call has already modified the layer's dtype policy:

```python
from keras import layers

layer = layers.Embedding(10, 16)
layer.build()
original_dtype_policy = layer.dtype_policy

try:
    layer.quantize("float8")  # Will fail: Embedding does not support this mode
except Exception:
    pass  # Ignore the error; only the resulting dtype policy matters here
print(original_dtype_policy == layer.dtype_policy)  # False
```

With this PR, the output will be True.
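The general idea behind the fix is a snapshot-and-rollback pattern: remember the dtype policy before attempting quantization and restore it if the attempt raises. The sketch below illustrates that pattern on a hypothetical `ToyLayer` class (not a Keras API, and not the PR's actual diff); the policy strings and the `_quantize_unsafe` helper are assumptions chosen only to mimic the buggy ordering, where the policy is changed before the mode is validated.

```python
# Hypothetical stand-in for a layer; this is NOT the real Keras implementation.
class ToyLayer:
    SUPPORTED_MODES = ("int8",)  # assumed set of supported modes

    def __init__(self):
        self.dtype_policy = "float32"

    def _quantize_unsafe(self, mode):
        # Mimics the bug: the policy is mutated before the mode is validated.
        self.dtype_policy = f"{mode}_from_float32"
        if mode not in self.SUPPORTED_MODES:
            raise NotImplementedError(f"mode {mode!r} is not supported")
        # Weight conversion would happen here in a real layer.

    def quantize(self, mode):
        # Snapshot the policy so a failed attempt leaves the layer unchanged.
        original_policy = self.dtype_policy
        try:
            self._quantize_unsafe(mode)
        except Exception:
            # Roll back, then re-raise so callers still see the original error.
            self.dtype_policy = original_policy
            raise


layer = ToyLayer()
before = layer.dtype_policy
try:
    layer.quantize("float8")
except NotImplementedError:
    pass
print(before == layer.dtype_policy)  # True
```

Re-raising after the rollback keeps the failure visible to the caller while guaranteeing the layer's state matches what it was before the call. An alternative design is to validate the mode before mutating any state, which avoids the need for a rollback at all.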