adefossez opened this issue 3 years ago
Hi @adefossez,
Thanks for using the code.
Basically, the STE used in the line you cited does not contribute to the gradient that is actually used for the update. The workflow is:
1. MetaQuant accesses the gradient of the quantized weights (grad) in: https://github.com/csyhhu/MetaQuant/blob/50694404fa010958aa24c96b4c870d5e9442e849/meta_utils/helpers.py#L23.
2. The obtained gradient of the quantized weights is used to generate the meta gradient.
3. The gradient of the pre-quantized weights (obtained via STE) is overridden in: https://github.com/csyhhu/MetaQuant/blob/50694404fa010958aa24c96b4c870d5e9442e849/meta-quantize.py#L239
Actually, my implementation in the line you cited is a little confusing: the forward quantization incorporates STE. However, as I mentioned, this gradient is overridden by the generated meta gradient.
Hope this helps. Let me know if anything is still unclear.
Best regards, Shangyu
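The workflow above can be sketched on a single scalar weight. This is a minimal pure-Python sketch, not the MetaQuant code: the binarizing quantizer `binarize`, the quadratic loss, the linear `meta_net` and its parameters `phi` are all hypothetical stand-ins. It shows that an STE gradient for the pre-quantized weight is still available during the backward pass, but the update uses the meta gradient instead.

```python
def binarize(w):
    """Hypothetical binary quantizer: sign(w), mapping 0 to +1."""
    return 1.0 if w >= 0 else -1.0


def loss_grad_wrt_quantized(w_q, target):
    """Gradient of the toy loss 0.5 * (w_q - target) ** 2 w.r.t. the quantized weight."""
    return w_q - target


def meta_net(grad, w, phi):
    """Hypothetical linear meta network mapping (grad, w) to a gradient for w."""
    return phi[0] * grad + phi[1] * w


def train_step(w, phi, target, lr=0.1):
    w_q = binarize(w)                              # forward pass on the quantized weight
    grad_q = loss_grad_wrt_quantized(w_q, target)  # step 1: gradient of the quantized weight
    ste_grad = grad_q                              # what STE would assign to w (identity)
    meta_grad = meta_net(grad_q, w, phi)           # step 2: meta gradient generated from grad_q
    grad_w = meta_grad                             # step 3: the STE gradient is overridden
    return w - lr * grad_w, ste_grad, meta_grad
```

For example, `train_step(0.3, (0.5, 0.1), target=-1.0)` computes an STE gradient of `2.0` but moves `w` by `-0.1 * 1.03`, the meta gradient. Note that this sketch only covers the update of the base weight; how `phi` itself gets a gradient is the question raised in the rest of the thread.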
Hi Shangyu,
Thank you for the clarification. However, it seems to me that STE is still used.
The output of the meta network is added to the weights here: https://github.com/csyhhu/MetaQuant/blob/master/meta_utils/meta_quantized_module.py#L89, followed by an STE function here: https://github.com/csyhhu/MetaQuant/blob/master/meta_utils/meta_quantized_module.py#L107
In the paper, you say this is how the meta network gets trained: through this contribution to the weight. So you replace the weight $w$ by $w + M(g)$, with $M(g)$ the output of the meta network. Now, you still need to apply the quantization function $Q$ to obtain $Q(w + M(g))$. As $Q$ is not differentiable, you wouldn't get a gradient here for the meta-network unless you are using STE, which seems to be the case.
Hi @adefossez ,
Thanks for elaborating; now I see your point.
Indeed, to obtain the gradient of $M(g)$ for training the meta network, I should use the output $M(g)$ together with the gradient of the quantized weights, as stated in Eq. 8 of my paper.
However, in this implementation, it seems that STE is still used to obtain the gradient of $M(g)$. Thanks for your careful reading.
I will change the implementation in the coming days.
Best regards, Shangyu
There are two equalities in eq (8).
The first equality is the regular chain rule. However, the gradient of the loss with respect to $\tilde{W}$ is zero wherever it is defined, because of the quantization step $Q$, so without STE you cannot use this expression to train the meta-network.
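Written out for a generic case, the obstruction in the first equality looks as follows. This is a sketch, not a reproduction of the paper's eq (8): it assumes the pre-quantized weights $\tilde{W}$ depend on the meta-network parameters $\phi$, treats the derivatives elementwise, and writes $\hat{W} = Q(\tilde{W})$ for the quantized weights (the symbol $\hat{W}$ is introduced here only for illustration).

```latex
\frac{\partial \ell}{\partial \phi}
  = \frac{\partial \ell}{\partial \hat{W}}
    \,\frac{\partial Q(\tilde{W})}{\partial \tilde{W}}
    \,\frac{\partial \tilde{W}}{\partial \phi},
\qquad
\frac{\partial Q(\tilde{W})}{\partial \tilde{W}} = 0
\quad \text{wherever it is defined.}
```

The STE makes this product non-zero by substituting $\partial Q(\tilde{W}) / \partial \tilde{W} := 1$ in the backward pass, which is the step the meta-network update relies on.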
There is a second equality, on the right of eq (8), where it seems you replace $\partial \ell / \partial \tilde{W}$ by $M_\phi(g, \tilde{W})$. However, you can only do so if the meta-network is properly trained. At initialization, the output of the meta-network is random; if you were to use the second part of (8) to update your model, your training would completely ignore the actual loss.
It seems to me you have a chicken-and-egg problem here. If the meta-network were trained, you could use its output as a gradient, but to train it initially, you must use STE.
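A toy sketch of the initialization concern follows; it is not MetaQuant code, and quantization is omitted to isolate the effect. A scalar weight minimizes $\ell(w) = \tfrac{1}{2}(w - 3)^2$ using either the true gradient or the output of an untrained linear "meta network" $M_\phi(g) = \phi g$ in place of the gradient. The coefficient `phi = -0.5` is a hypothetical stand-in for an unlucky random initialization: nothing forces an untrained network's output to be a descent direction, and here it points uphill.

```python
def loss(w, target=3.0):
    return 0.5 * (w - target) ** 2


def true_grad(w, target=3.0):
    return w - target


def untrained_meta_net(g, phi=-0.5):
    """Hypothetical linear meta network with a fixed, untrained coefficient."""
    return phi * g


def run(update_fn, w0=0.0, lr=0.1, steps=20):
    """Gradient-descent-style loop where update_fn maps the true gradient to the applied update."""
    w = w0
    for _ in range(steps):
        w -= lr * update_fn(true_grad(w))
    return loss(w)
```

Starting from a loss of `4.5`, `run(lambda g: g)` shrinks the error by a factor of 0.9 per step, while `run(untrained_meta_net)` grows it by a factor of 1.05 per step, so the loss increases.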
Actually, I have replaced the gradient of $\tilde{W}$ by $M_\phi(g, \tilde{W})$, as you mentioned. Basically, it looks like using the output of a neural network as part of its own gradient to train itself (a chicken-and-egg problem). I think it's a very interesting point.
Currently, this implementation trains the quantized base model using the meta gradient, but trains the meta network using STE. Anyhow, I think it would be very interesting if the meta network could be trained from a random initialization using the gradient provided by its own output, and I am working on it. Thanks for your careful reading.
I don't think it can be trained from a randomly initialized meta-network, because the meta-network output would be independent of the loss due to the random initialization. It means you would get the same gradient distribution for any loss function, i.e., your model would not be at all specific to the given loss and dataset.
Anyway, this would be a different set of experiments. In that sense, I think you should update your paper following our discussion: it makes some misleading claims that you train without STE, while in fact you are using STE, and it incorrectly describes using eq (8) to update the meta network instead of STE.
Hello, I am trying to understand how MetaQuant works. Reading the paper, it seems that STE is not required at all. However, reading the code, it seems STE is used to derive a gradient: https://github.com/csyhhu/MetaQuant/blob/master/utils/quantize.py#L25 and https://github.com/csyhhu/MetaQuant/blob/master/utils/quantize.py#L41
It would be very helpful if you could help me understand this point.