bferrarini opened this issue 4 years ago
@bferrarini does your code work if you replace DoReFa with some other quantiser, such as SteSign?
Hi @AdamHillier,

I ran the same code using SteSign on the weights and it worked:
```
10000/1 1s 93us/sample - loss: 1.3581 - accuracy: 0.6566
Test loss: 1.3786100630760192
Test accuracy: 0.6566
Training done in 170.44 seconds for 30 epochs.
```
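(A minimal sketch of the kind of swap being tested here; the actual model is not shown in this thread, so the layer shape and clip values below are illustrative only.)

```python
# Illustrative sketch (not the actual model from this thread): swapping the
# kernel quantizer from DoReFa to SteSign while the activations keep DoReFa.
import larq as lq

layer = lq.layers.QuantConv2D(
    32, (3, 3),
    input_quantizer=lq.quantizers.DoReFaQuantizer(k_bit=2),
    kernel_quantizer=lq.quantizers.SteSign(clip_value=1.0),  # instead of DoReFa
    kernel_constraint=lq.constraints.WeightClip(clip_value=1.0),
)
```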
According to my understanding, DoReFa should clip the gradient as expected (link). I implemented a version of DoReFa with gradient clipping. Here are the results.

DoReFa + MyDoReFa w/ clipping (2 bits): converges slowly, but converges.
```
10000/ 101us/sample - loss: 2.1806 - accuracy: 0.1226
Test loss: 2.2232713317871093
Test accuracy: 0.1226
Training done in 178.93 seconds for 30 epochs.
```
DoReFa + MyDoReFa w/ clipping (1 bit): converges.
```
10000/ 1s 97us/sample - loss: 1.3921 - accuracy: 0.4745
Test loss: 1.5290546701431273
Test accuracy: 0.4745
Training done in 178.45 seconds for 30 epochs.
```
Here is the clipped DoReFa quantizer I implemented for the tests:
```python
import tensorflow as tf
from larq import utils
from larq.quantizers import BaseQuantizer, _clipped_gradient


@utils.register_keras_custom_object
class MyDoReFaQuantizer(BaseQuantizer):
    precision = None

    def __init__(self, k_bit: int = 2, **kwargs):
        self.precision = k_bit
        super().__init__(**kwargs)

    def call(self, inputs):
        inputs = tf.clip_by_value(inputs, 0, 1.0)

        @tf.custom_gradient
        def _k_bit_with_identity_grad(x):
            def grad(dy):
                # Mask the gradient outside the clip range, via larq's helper.
                return _clipped_gradient(x, dy, clip_value=1.0)

            n = 2 ** self.precision - 1
            return tf.round(x * n) / n, grad

        outputs = _k_bit_with_identity_grad(inputs)
        return super().call(outputs)

    def get_config(self):
        return {**super().get_config(), "k_bit": self.precision}
```
Kind regards,
Bruno
Hi @bferrarini,

Thanks for opening the issue and providing the problematic code! I have successfully reproduced the problem; the model indeed does not seem to train when using the DoReFaQuantizer for the weights, and does train when using SteSign instead. You are correct that the DoReFaQuantizer should clip the gradient, and in fact it does. The plot in our documentation you linked to is actually generated using the quantizer, so it is definitely doing the right thing. Your adapted version, MyDoReFaQuantizer, does exactly the same thing; the gradient was already clipped (because the inputs are clipped in the forward pass, and therefore so is the gradient), so adding another clip doesn't make a difference. To check this out for yourself, you can use the following code:
```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def calculate_activation(function, x):
    tf_x = tf.Variable(x)
    with tf.GradientTape() as tape:
        activation = function(tf_x)
    return activation.numpy(), tape.gradient(activation, tf_x).numpy()


def plot_quantizer(quantizer):
    x = np.linspace(-2, 2, 500).astype(np.float32)
    y, dy = calculate_activation(quantizer, x)
    plt.subplot(121)
    plt.grid()
    plt.plot(x, y)  # forward pass
    plt.subplot(122)
    plt.grid()
    plt.plot(x, dy)  # gradient
    plt.show()


plot_quantizer(MyDoReFaQuantizer())
```
This will plot the following figure (forward pass on the left, gradient on the right), which is indeed the same as that of the normal DoReFaQuantizer:

[figure: MyDoReFaQuantizer forward pass and gradient]
Nevertheless, I'm surprised that the model does not train. Even when I use your MyDoReFaQuantizer with k_bit = 1 or k_bit = 2, it never obtains more than 10% accuracy. Could you please send me the code that allowed you to obtain 47% accuracy? I suspect you have accidentally changed something else, which has somehow resolved the problem.
I also had another look at the DoReFa paper, and found that for k_bit > 1 they actually use a different quantization formula for the weights than for the activations (i.e. different from the one currently in Larq):

```
w_q = 2 * quantize_k( tanh(w) / (2 * max(|tanh(w)|)) + 1/2 ) - 1
```

Perhaps this explains why the DoReFaQuantizer doesn't work well when using it on weights with k_bit = 2, but I'd still expect it to perform better than random (i.e. obtain an accuracy > 10%)...
Hi @jneeven,

You are right in both cases. 1) MyDoReFa is equivalent to DoReFa (many thanks for the plot_quantizer function, it is handy). 2) I confirm I cannot reproduce the experiment that obtained 47% accuracy. As you supposed, I probably made a mistake in the first run; possibly I did not use the MyDoReFa quantizer in some part of the code.
Kind regards,
Bruno
I had another look at the DoReFa paper and have concluded that the issues here stem from the fact that they use a different quantization formula for the weights than for the activations (as I mentioned above). This probably leads to gradient issues, preventing your model from training altogether. In case you want to use binary weights, there is no point in quantizing them with DoReFa, as even the authors themselves just resort to SteSign in that case.

If you want to use DoReFa for weights with k_bit >= 2, you will need to use the formula described in the paper. A hacky implementation could look somewhat like this (I have not verified that everything is correct):
```python
import larq as lq
import tensorflow as tf


def _clipped_gradient(x, dy, clip_min, clip_max):
    """Calculate `clipped_gradient * dy`."""
    zeros = tf.zeros_like(dy)
    mask = tf.math.logical_and(tf.math.greater(x, clip_min), tf.math.less(x, clip_max))
    return tf.where(mask, dy, zeros)


class CustomDoReFaQuantizer(lq.quantizers.DoReFaQuantizer):
    r"""DoReFa with separate quantization formulas for weights and activations."""

    def __init__(self, k_bit: int = 2, mode="activations", **kwargs):
        self.precision = k_bit
        self.n = 2 ** self.precision - 1
        if mode not in ["weights", "activations"]:
            raise ValueError(f"DoReFa received unknown mode: {mode}")
        self.mode = mode
        super().__init__(k_bit=k_bit, **kwargs)

    def call(self, inputs):
        if self.mode == "activations":

            @tf.custom_gradient
            def quantize_k(x):
                x = tf.clip_by_value(x, 0, 1.0)
                return (
                    tf.round(x * self.n) / self.n,
                    lambda dy: _clipped_gradient(inputs, dy, 0, 1.0),
                )

            return quantize_k(inputs)
        else:

            @tf.custom_gradient
            def quantize_k(x):
                # Straight-through estimator: pass the gradient through as-is.
                return tf.round(x * self.n) / self.n, lambda dy: dy

            tan = tf.math.tanh(inputs)
            fraction = tan / (2 * tf.math.reduce_max(tf.math.abs(tan)))
            return 2.0 * quantize_k(fraction + 0.5) - 1
```
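(A hypothetical way to plug it in; the layer and bit width below are illustrative, not a configuration from this thread.)

```python
import larq as lq

# Hypothetical usage: the paper's weight formula on the kernel, the plain
# activation formula on the inputs.
layer = lq.layers.QuantDense(
    128,
    input_quantizer=CustomDoReFaQuantizer(k_bit=2, mode="activations"),
    kernel_quantizer=CustomDoReFaQuantizer(k_bit=2, mode="weights"),
)
```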
I tried this with your code, and although the accuracy is still very low (I think it was around 14%), it does at least train.

That being said, DoReFa is somewhat outdated at this point, and a much more interesting quantizer to use would be LSQ, from the paper Learned Step Size Quantization. It does basically the same thing as DoReFa (the scalar multiplier can be fused with the batch norm), but has much better gradients. A hacky implementation of that (which I already had lying around for some toy examples) could look like this:
```python
import numpy as np
import tensorflow as tf
from larq import utils


@tf.custom_gradient
def scaled_gradient(x: tf.Tensor, scale: float = 1.0) -> tf.Tensor:
    def grad(dy):
        # We don't return a gradient for `scale` as it isn't trainable
        return (dy * scale, 0.0)

    return x, grad


@utils.register_alias("lsq")
@utils.register_keras_custom_object
class LSQ(tf.keras.layers.Layer):
    r"""Instantiates a serializable k_bit quantizer as in the LSQ paper.

    # Arguments
        k_bit: number of bits for the quantization.
        mode: either "signed" or "unsigned", reflecting the activation
            quantization scheme to use. When using this for weights, use
            mode "weights" instead.

    # Returns
        Quantization function

    # References
        - [Learned Step Size Quantization](https://arxiv.org/abs/1902.08153)
    """

    precision = None

    def __init__(self, k_bit: int = 2, mode="unsigned", **kwargs):
        self.precision = k_bit
        self.mode = mode
        if mode == "unsigned":
            self.q_n = 0.00
            self.q_p = float(2 ** self.precision - 1)
        elif mode in ["signed", "weights"]:
            self.q_p = float(2 ** (self.precision - 1)) - 1
            # For signed, we can use the full signed range, e.g. [-2, 1]
            if mode == "signed":
                self.q_n = -float(2 ** (self.precision - 1))
            # For weights, we use a symmetric range, e.g. [-1, 1]
            else:
                self.q_n = -float(2 ** (self.precision - 1) - 1)
        else:
            raise ValueError(f"LSQ received unknown mode: {mode}")
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.s = self.add_weight(
            name="s",
            initializer="ones",
            trainable=True,
            aggregation=tf.VariableAggregation.MEAN,
        )
        self._initialized = self.add_weight(
            name="initialized",
            initializer="zeros",
            dtype=tf.dtypes.bool,
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        # Assuming that by num_features they mean all the individual pixels.
        # You can also try the number of feature maps instead.
        self.g = float(1.0 / np.sqrt(np.prod(input_shape[1:]) * self.q_p))
        super().build(input_shape)

    def call(self, inputs):
        # Calculate initial value for the scale using the first batch
        self.add_update(
            self.s.assign(
                tf.cond(
                    self._initialized,
                    lambda: self.s,  # If already initialized, use current value
                    # Otherwise, use the value below as initialization
                    lambda: (2.0 * tf.reduce_mean(tf.math.abs(inputs)))
                    / tf.math.sqrt(self.q_p),
                )
            )
        )
        self.add_update(self._initialized.assign(True))

        s = scaled_gradient(self.s, self.g)
        rescaled_inputs = inputs / s
        clipped_inputs = tf.clip_by_value(rescaled_inputs, self.q_n, self.q_p)

        @tf.custom_gradient
        def _round_ste(x):
            return tf.round(x), lambda dy: dy

        return _round_ste(clipped_inputs) * s

    def get_config(self):
        return {**super().get_config(), "k_bit": self.precision, "mode": self.mode}
```
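(A hypothetical usage sketch; the layer and bit width are illustrative choices, not a tested configuration.)

```python
import larq as lq

# Hypothetical configuration: LSQ on unsigned activations and symmetric weights.
layer = lq.layers.QuantConv2D(
    64, (3, 3),
    input_quantizer=LSQ(k_bit=2, mode="unsigned"),
    kernel_quantizer=LSQ(k_bit=2, mode="weights"),
)
```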
I can't make any guarantees regarding LSQ or this code either, but it's definitely worth a try. Thank you for pointing out this issue; I will update the DoReFa documentation to make clear that, in its current state, it is intended only for activations.
Many thanks for the effort in addressing the issue and for providing the code for LSQ. I will try it.
**Describe the bug**

Hi, I tried to use the DoReFa quantizer to train a simple model for CIFAR10, but training failed to converge. I found that the problem only occurs when the DoReFa quantizer is used on the kernel (kernel_quantizer).

This code reproduces the problem; DoReFa is used for both activations and weights.
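(The original reproduction code is not shown in this extract; the following is a minimal approximation, and the architecture is a guess rather than the reporter's actual model.)

```python
import larq as lq
import tensorflow as tf

def dorefa():
    return lq.quantizers.DoReFaQuantizer(k_bit=2)

model = tf.keras.Sequential([
    lq.layers.QuantConv2D(
        32, (3, 3),
        input_quantizer=dorefa(),
        kernel_quantizer=dorefa(),  # the reported failure only occurs with this set
        input_shape=(32, 32, 3),
    ),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    lq.layers.QuantDense(10, input_quantizer=dorefa(), kernel_quantizer=dorefa()),
    tf.keras.layers.Activation("softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```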
**Expected behavior**

Training converges for any precision set in DoReFa (the k_bit parameter).
**Environment**

TensorFlow version: 2.0.0
Larq version: 0.9.4