wuzhiyang2016 opened this issue 4 years ago
Hi,
As you can see in the method merge_quantize_weights(), the parameters are quantized. So both parameters and activations are quantized.
But both effects, parameter quantization and activation quantization, are removed from the backward pass (only from backward) by that line: y.data.copy_(yq.data). The forward pass still has the effects of quantization, because the contents of yq are copied into y and passed to the next layer.
This is because all the operations that were done to produce yq are skipped from the graph (used for backward) by that one line. But the contents of yq itself are kept and propagated to the next layer because they are copied into y.
During backward, whatever operations were used to create the tensor y are replayed to produce the gradients, because tensor y is what we propagate to the next layer. But the content of tensor y is now yq, so the next layer uses the quantized values. This is what Straight-Through Estimation (STE) is. It's a bit complicated.
To understand this better, you can study how PyTorch records graph operations and how parts of the graph can be excluded from the backward pass.
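To make the mechanism concrete, here is a minimal, hypothetical sketch of the data-copy trick described above (fake_quantize here is an illustrative stand-in, not the devkit's actual quantization code):

```python
import torch

def fake_quantize(x, scale=0.1):
    # Simulated quantization: snap values to a grid (the round() step is non-differentiable).
    return torch.round(x / scale) * scale

x = torch.randn(4, requires_grad=True)
y = torch.relu(x)            # y is part of the autograd graph
yq = fake_quantize(y)        # yq's graph contains the round() op
y.data.copy_(yq.data)        # forward values become the quantized ones,
                             # but backward still replays y's original graph
loss = y.sum()
loss.backward()
print(x.grad)                # gradient of relu only; the quantizer is bypassed (STE)
```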
Yes, the operations in merge_quantize_weights() are outside the computation graph, so they have no effect on the gradients of the model parameters. Assume a situation like this: the last layer's output is y and the current layer is a conv layer with weights but no bias. If we don't do **y.data.copy_(yq.data)**, this conv layer's weight gradient is computed from y instead of yq, which is wrong, so we correct the weight's gradient by doing y.data.copy_(yq.data).
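A hedged illustration of that situation (hypothetical code, not from the devkit): the conv layer's weight gradient is computed from the layer's input, so copying yq into y changes that gradient.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, bias=False)   # current layer: weights, no bias
y = torch.randn(1, 1, 8, 8, requires_grad=True)     # last layer's output
yq = torch.round(y * 8) / 8                         # fake-quantized activation (illustrative)
y.data.copy_(yq.data)                               # the line being discussed

out = conv(y)
out.sum().backward()
# conv.weight.grad was accumulated using the quantized values now stored in y
print(conv.weight.grad)
```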
But this correction of the weight's gradient only accounts for the error introduced by activation quantization. Quantizing the parameters introduces an additional error, so we should also correct the conv layer's w to wq, as you can see in Google's paper "Quantizing deep convolutional networks for efficient inference: A whitepaper", formula (16).
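To illustrate the point about w vs. wq (illustrative code, not a quotation of the paper's formula): the gradient that flows back through a layer is computed with whatever weight tensor sits in the graph, so putting wq into the graph changes the backward pass.

```python
import torch

x = torch.randn(5, requires_grad=True)
w = torch.randn(5, requires_grad=True)
wq = torch.round(w * 4) / 4            # fake-quantized copy of w (illustrative)

(x @ w).backward()                     # backward uses the float weights
grad_with_w = x.grad.clone()
x.grad = None

(x @ wq).backward()                    # backward uses the quantized weights
grad_with_wq = x.grad.clone()

# x receives w in the first case and wq in the second: which values flow back
# depends entirely on which weight tensor was used in the forward graph.
print(torch.allclose(grad_with_w, w), torch.allclose(grad_with_wq, wq))
```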
Interesting discussion. Let's continue. You said: "we should correct the conv layer's w to wq".
I think it is clear that the weights used in the forward pass in this code are quantized, as I pointed out for merge_quantize_weights(). There is no doubt about that.
I think what you are saying is that the weights that this code gives out for inference (for example, exported via an ONNX file) should also be quantized. Is my understanding correct?
Let me write it step by step and you tell me in which step you think there should be a modification:
I am very glad to hear from you~ What I mean is: when calculating gradients, we should use wq instead of w. I just checked the code of the merge_quantize_weights() function. We call it as conv, weight, bias = self.merge_quantize_weights(qparams, conv, bn), and then fq is calculated from weight and bias. Is that what you meant by "I think it is clear that the weights used in forward in this code is quantized"?
If that is the case, my confusion is that merge_quantize_weights() doesn't change the original conv's weight.
Saving the model parameters in float is okay; the TDA4 model tool will convert them to int.
From the paper that you quoted, section 3.2: "However, we maintain weights in floating point and update them with the gradient updates. This ensures that minor gradient updates gradually update the weights instead of underflowing." It seems that in that paper also, the floating-point weights are the ones that are kept and updated by backpropagation. Quantized weights are used for forward only.
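For concreteness, a minimal sketch of the scheme described in that quote (illustrative names, not devkit code): the float master weight is the one the optimizer updates, while a quantized view of it is used in the forward (and backward) computation.

```python
import torch

w = torch.randn(10, requires_grad=True)                  # float master weight
opt = torch.optim.SGD([w], lr=0.1)

for step in range(3):
    wq = w + (torch.round(w * 16) / 16 - w).detach()     # quantize with STE: forward value is wq
    loss = (wq ** 2).sum()                                # forward (and backward) use wq
    opt.zero_grad()
    loss.backward()                                       # gradient passes straight through to w
    opt.step()                                            # minor updates accumulate in the float w
```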
OK, so jacinto-ai-devkit's implementation doesn't replace w with wq in the backward pass.
But in section 3.2 the paper also says: "The updated weights are quantized and used for subsequent forward and backward pass computation."
Of course, in my experience it doesn't matter for model precision whether or not wq replaces w.
Okay. It is clearer now.
So what you are saying is that the backward computation should use wq, yq (but update w, as per the paper). But instead, what is happening is that the backward computation uses w, yq and updates w. (The forward computation uses wq; no confusion about that.)
You said "we should add a line of code: w.data.copy_(wq.data)". We should be careful in doing that, as we don't want to overwrite w (which has to be kept in float). Maybe a copy should be used.
You also said "it doesn't matter for model precision whether or not wq replaces w" - are you sure about that?
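A hedged sketch of the "use a copy" idea mentioned above (hypothetical, not the devkit's code): keep a float backup, put the quantized values into the weight for forward and backward, and restore the float values before the optimizer update.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, bias=False)
opt = torch.optim.SGD(conv.parameters(), lr=0.1)
x = torch.randn(1, 1, 8, 8)

w_float = conv.weight.data.clone()                        # keep the float master copy
conv.weight.data.copy_(torch.round(w_float * 16) / 16)    # wq is used in forward and backward
loss = conv(x).sum()
opt.zero_grad()
loss.backward()                                           # gradients computed with wq in the graph
conv.weight.data.copy_(w_float)                           # restore float before the update
opt.step()                                                # so the float weights keep accumulating updates
```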
In 2018~2019 I participated in a PyTorch reimplementation of Google's quantization training paper at the company I worked for. We implemented it strictly following the formulas, in both forward and backward. The precision of our quantized models was nearly equal to the float models on various tasks, such as detection, segmentation, and so on. In TI's report at https://git.ti.com/cgit/jacinto-ai/pytorch-jacinto-ai-devkit/about/docs/Quantization.md I also don't see a big decrease in precision.
Interesting! Thanks for pointing this out - something to think about.
I have cleaned up the implementation of STE a bit for better understanding.
yq = layers.PropagateQuantTensorSTE(None)(y, yq). The forward will use yq, but the gradient will flow back, skipping the fake-quantization part.
See it being used here: https://git.ti.com/cgit/jacinto-ai/pytorch-jacinto-ai-devkit/tree/modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py#n347
This, of course, is our version of STE, in which the backward doesn't use the quantized weights (and, as you pointed out, it is a bit different from the paper that you mentioned).
If you have a function called round and you want to apply STE to it, wrap it in the class PropagateQuantTensorSTE: round_g = PropagateQuantTensorSTE(round). Calls to the function round_g() will then happen in the STE way, i.e. the backward of that function will be bypassed and the gradient will flow back straight through. Example: https://git.ti.com/cgit/jacinto-ai/pytorch-jacinto-ai-devkit/tree/modules/pytorch_jacinto_ai/xnn/layers/functional.py#n8
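For readers who want to see the mechanism in isolation, here is a minimal sketch of how such an STE wrapper can be built with torch.autograd.Function (this is not the devkit's PropagateQuantTensorSTE; see the functional.py link above for the real one):

```python
import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)      # the non-differentiable op used in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output         # backward is bypassed: the gradient flows straight through

x = torch.randn(4, requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()
print(x.grad)                      # all ones, as if round() were the identity
```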
Hello, in the quant_train_module.py file I saw a line of code: y.data.copy_(yq.data). This code changes the data of the ReLU's output so that the quantized values are used in backward when calculating the activation's gradient, which accounts for the loss introduced by activation quantization. But I couldn't see any code that takes care of the weight or bias quantization loss.
In my opinion, we should add a line of code, w.data.copy_(wq.data), for the weights of conv layers and fully connected layers.
So what I want to check is: does jacinto-ai-devkit only take care of the activation quantization loss?
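For clarity, this is roughly what that proposal would look like in code (hypothetical placement and quantizer, not devkit code), together with the caveat discussed elsewhere in the thread:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, bias=False)
w = conv.weight
wq = torch.round(w.detach() * 16) / 16   # fake-quantized weights (illustrative quantizer)
w.data.copy_(wq.data)                    # proposed line: backward now sees the quantized values
# Caveat from the discussion: this overwrites the float master weights,
# so a separate float copy would have to be kept for the gradient updates.
```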