google / qkeras

QKeras: a quantization deep learning library for Tensorflow Keras
Apache License 2.0

print_qstats raises ValueError with custom objects #90

Closed: boubasse closed this issue 2 years ago

boubasse commented 2 years ago

Hi,

I'm encountering an issue with the print_qstats function when using a custom kernel constraint. Suppose I define a custom constraint:

import tensorflow as tf

class CustomConstraint(tf.keras.constraints.Constraint):
    def __call__(self, w):
        # some logic here; returning w unchanged is only a placeholder
        return w

Now suppose I build a model with the functional API that uses the custom constraint above:

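from tensorflow.keras import Input, Model
from qkeras import QDense, quantized_bits
inp = Input(shape=(16,))  # example shape and units; use your own values
out = QDense(units=10, kernel_constraint=CustomConstraint(),
             kernel_quantizer=quantized_bits(8))(inp)
model = Model(inp, out)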

Calling print_qstats on this model raises a ValueError caused by the CustomConstraint object. The problem comes from the clone_model call inside print_qstats (https://github.com/google/qkeras/blob/1334b68ef2a9a95655cc5ee6ee7001e98453e7ff/qkeras/utils.py#L1057). I see that clone_model accepts custom objects; however, print_qstats offers no way to pass them through.

Do you have a quick solution for this? A proper fix would be to let print_qstats accept a custom objects dictionary, or to make one available globally if it is needed elsewhere.
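One possible stopgap, not confirmed by the maintainers, is Keras's `tf.keras.utils.custom_object_scope`, which makes extra names resolvable during deserialization without changing any function signature. The sketch below assumes TensorFlow 2.x (`tf.keras`) and a constraint config in the `{"class_name": ..., "config": ...}` form; it demonstrates the mechanism on a standalone constraint and does not call print_qstats itself.

```python
import tensorflow as tf

# Unregistered custom constraint, as in the original report.
class CustomConstraint(tf.keras.constraints.Constraint):
    def __call__(self, w):
        return w  # placeholder: no-op constraint

# A serialized constraint that refers to the class only by its name,
# similar to what appears inside a layer config (format assumed here).
config = {"class_name": "CustomConstraint", "config": {}}

# Inside the scope, Keras can map the name "CustomConstraint" to the class.
with tf.keras.utils.custom_object_scope({"CustomConstraint": CustomConstraint}):
    restored = tf.keras.constraints.deserialize(config)

print(type(restored).__name__)  # CustomConstraint
```

Applied to this issue, the same idea would be to wrap the call as `with tf.keras.utils.custom_object_scope({"CustomConstraint": CustomConstraint}): print_qstats(model)`. This is untested against print_qstats and assumes the clone_model call inside it deserializes through Keras, which consults the active scope.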

Thank you for your time,

boubasse commented 2 years ago

I've just found the solution to my issue, haha...

One must add the following decorator just above the CustomConstraint class definition:

import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package="Custom", name="CustomConstraint")
class CustomConstraint(tf.keras.constraints.Constraint):
    def __call__(self, w):
        # some logic here; returning w unchanged is only a placeholder
        return w

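Doing so automatically adds the custom constraint to the global Keras custom objects, so it no longer has to be passed in explicitly. Warning: this is enough for functions and for classes whose constructor takes no arguments, like this constraint; for layers, you must also override the get_config method.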
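A minimal sketch of both points, assuming TensorFlow 2.x (`tf.keras`): the decorator registers the constraint under the name `package>name`, and a registered layer with a constructor argument also needs `get_config` so it can be rebuilt from its config. `Scale` and its `scale` argument are hypothetical, invented only for illustration.

```python
import tensorflow as tf

# Registered under "Custom>CustomConstraint", as with the decorator above.
@tf.keras.utils.register_keras_serializable(package="Custom", name="CustomConstraint")
class CustomConstraint(tf.keras.constraints.Constraint):
    def __call__(self, w):
        return w  # placeholder: no-op constraint

# Hypothetical layer with a constructor argument: get_config must return
# that argument so from_config can call Scale(scale=...) again.
@tf.keras.utils.register_keras_serializable(package="Custom", name="Scale")
class Scale(tf.keras.layers.Layer):
    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        return inputs * self.scale

    def get_config(self):
        config = super().get_config()  # name, trainable, dtype, ...
        config.update({"scale": self.scale})
        return config

print(tf.keras.utils.get_registered_name(CustomConstraint))  # Custom>CustomConstraint

# Round-trip both objects through their serialized configs without passing
# custom_objects; the global registry resolves the registered names.
constraint = tf.keras.constraints.deserialize(
    tf.keras.constraints.serialize(CustomConstraint()))
layer = tf.keras.layers.deserialize(tf.keras.layers.serialize(Scale(scale=2.0)))
print(type(constraint).__name__, layer.scale)  # CustomConstraint 2.0
```

Without the `get_config` override, `tf.keras` in TensorFlow 2.x refuses to produce a config for a layer whose `__init__` takes extra arguments, which is why registration alone is not sufficient for such layers.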

Hope this helps other people.