hpi-xnor / BMXNet

(New version is out: https://github.com/hpi-xnor/BMXNet-v2) BMXNet: An Open-Source Binary Neural Network Implementation Based on MXNet
Apache License 2.0

QFullyConnected mutates weights in forward mode #30

Closed: analog-cbarber closed this issue 5 years ago

analog-cbarber commented 6 years ago

It appears that QFullyConnected writes the quantized weights back into the input weight tensor even when not in training mode. It is surprising for an operator to mutate its weight tensor like this. If this is not a bug, the behavior should be clearly documented in the operator description.

>>> import mxnet as mx
>>> W = mx.random.normal(shape=(4,64))
>>> W
[[-2.11374566e-01 -7.62305379e-01  5.35988867e-01 -4.48779255e-01
  -1.31812811e-01 -9.39604759e-01  8.35931599e-01  3.33142400e-01
   1.88610386e-02 -4.82557148e-01  7.96305165e-02  2.28180361e+00
  -6.15945280e-01 -6.84607267e-01 -1.09104884e+00  1.61278319e+00
  -4.62019652e-01  1.15526748e+00  1.00387812e+00  6.51997179e-02
  -2.08390355e-01 -5.01749277e-01 -8.90954554e-01  1.93811685e-01
  -3.69189644e+00  1.33110189e+00 -9.12502468e-01  6.57643005e-02
  -1.09751368e+00 -9.91342485e-01 -1.21290708e+00  6.61847472e-01
  -2.20562726e-01  1.52223051e-01  6.54029310e-01 -4.36110109e-01
   6.78317189e-01 -4.90361512e-01 -1.13644254e+00 -1.15610786e-01
  -1.22058713e+00  5.92948437e-01  1.15824485e+00  8.71689692e-02
  -1.06366360e+00  7.94529617e-01 -1.97111309e+00  4.99654144e-01
   7.78103471e-01 -9.06336457e-02  1.36469460e+00  9.52839136e-01
   7.28555679e-01  2.49940425e-01 -3.67091447e-01  2.34669849e-01
   1.23725939e+00  7.70155713e-02  7.63777673e-01 -2.70560622e-01
  -3.04230303e-02 -5.69541216e-01 -4.35389206e-02 -2.02609086e+00]
 [ 1.19611490e+00 -4.55334902e-01  1.75488353e-01 -1.21917176e+00
  -2.98362315e-01 -1.93958059e-01  1.80431500e-01 -1.58335018e+00
  -1.61724344e-01  1.60257757e-01 -3.08117604e+00 -1.37699589e-01
   2.87654519e-01 -1.49461657e-01 -3.96128535e-01  2.14600182e+00
   4.24181908e-01  3.94673020e-01 -1.84842292e-02 -1.17970586e+00
   9.18654054e-02  8.21183503e-01 -2.83561778e+00  1.59463704e-01
  -6.14835680e-01 -1.63099396e+00 -8.21941197e-02  1.87127218e-02
   1.70377719e+00 -2.62416095e-01  1.14750612e+00 -7.83303559e-01
   6.05888128e-01  6.09731436e-01 -2.25910731e-02 -9.14791644e-01
   1.02548385e+00 -3.56592703e+00  1.29791510e+00  4.42981362e-01
  -7.46885777e-01  1.02512610e+00 -7.97469497e-01 -3.27157199e-01
   6.98440671e-01 -8.62959862e-01 -9.37188506e-01  1.27880239e+00
  -2.33837748e+00 -3.82108897e-01 -6.23956919e-02 -8.48336697e-01
  -9.68048036e-01 -2.98008025e-01  9.47782397e-02  4.11213666e-01
   4.17784423e-01  9.08401981e-02  2.04132140e-01  1.24544680e+00
   5.85648179e-01  6.69055283e-01 -1.39358103e+00  5.04939497e-01]
 [ 9.20787096e-01  9.13565159e-01  1.52436423e+00 -7.06564724e-01
  -4.66956079e-01  3.56256664e-02 -4.71516877e-01  4.01355475e-01
   5.14568210e-01 -8.81631017e-01 -4.48225707e-01 -1.55657268e+00
  -1.13136508e-01 -1.88967620e-03 -1.17206562e+00 -5.11925995e-01
  -1.65847576e+00 -3.38403374e-01  1.68761730e+00 -1.71251976e+00
  -1.30054665e+00  1.02668285e-01 -2.58739978e-01 -6.71934068e-01
   1.46498546e-01  3.35435748e-01  4.68158603e-01 -3.10511351e-01
  -1.41961992e+00 -5.00294864e-01  9.75775719e-01 -2.83480972e-01
  -1.56842291e-01  8.74613285e-01  8.50444660e-02 -1.82479694e-01
  -4.73392665e-01 -9.15907085e-01 -8.06360245e-01 -6.43816411e-01
  -6.91942811e-01  7.48873591e-01 -7.36202061e-01  7.20680177e-01
   8.20632339e-01  1.83446443e+00 -1.45658314e+00 -6.36922061e-01
   1.22709394e+00  8.55946958e-01  9.65574801e-01  5.68778694e-01
  -1.02208860e-01  1.36076117e+00  3.91971320e-01  3.41300428e-01
   3.70879382e-01 -1.07574785e+00  1.05239189e+00  8.15406501e-01
   1.07894875e-01 -3.64720911e-01  2.12204620e-01  9.17427897e-01]
 [ 1.15112793e+00 -2.00505897e-01  9.29222584e-01 -7.11516738e-02
  -8.05326998e-01  1.32869601e+00 -9.25439358e-01 -6.03633940e-01
  -2.48306438e-01  3.89059186e-01  9.18562055e-01 -3.78619999e-01
   1.00211866e-01  7.20045030e-01 -4.44365352e-01 -2.64862776e+00
  -1.18471313e+00  1.16577756e+00 -6.09033763e-01  9.64892924e-01
   1.43267602e-01  1.88822067e+00 -2.35196084e-01 -2.37704784e-01
  -1.39442050e+00 -2.20630479e+00 -2.18164459e-01  1.50160953e-01
  -7.75259554e-01  6.50879443e-01 -8.46705019e-01  1.04838349e-01
  -7.26454630e-02 -7.21233130e-01 -9.52106655e-01  1.59448719e+00
  -9.63124096e-01 -1.21563292e+00 -6.99505329e-01 -1.20860569e-01
  -2.78758675e-01  7.75578797e-01 -4.66849864e-01 -6.78790927e-01
   1.25006175e+00 -2.72246242e-01 -1.13920772e+00  1.05596157e-02
   8.58640492e-01 -3.42171431e-01  1.21449947e+00  2.70008862e-01
  -1.82649934e+00  4.53750230e-02 -6.52859628e-01  3.11093211e-01
  -4.11078960e-01 -1.70676017e+00 -3.61594371e-02  2.44527057e-01
   2.20263505e+00 -9.06375766e-01 -1.25763461e-01  4.25077640e-02]]
<NDArray 4x64 @cpu(0)>
>>> mx.ndarray.QFullyConnected(data=mx.ndarray.ones((1, 64)), weight=W,
...                            num_hidden=4, act_bit=2, weight_bit=2)
[[-2.6666646 -3.3333306  2.0000014 -5.333333 ]]
<NDArray 1x4 @cpu(0)>
>>> W
[[-0.3333333  -0.3333333   0.33333337 -0.3333333  -0.3333333  -1.
   1.          0.33333337  0.33333337 -0.3333333   0.33333337  1.
  -0.3333333  -0.3333333  -1.          1.         -0.3333333   1.
   1.          0.33333337 -0.3333333  -0.3333333  -1.          0.33333337
  -1.          1.         -1.          0.33333337 -1.         -1.
  -1.          0.33333337 -0.3333333   0.33333337  0.33333337 -0.3333333
   0.33333337 -0.3333333  -1.         -0.3333333  -1.          0.33333337
   1.          0.33333337 -1.          0.33333337 -1.          0.33333337
   0.33333337 -0.3333333   1.          1.          0.33333337  0.33333337
  -0.3333333   0.33333337  1.          0.33333337  0.33333337 -0.3333333
  -0.3333333  -0.3333333  -0.3333333  -1.        ]
 [ 1.         -0.3333333   0.33333337 -1.         -0.3333333  -0.3333333
   0.33333337 -1.         -0.3333333   0.33333337 -1.         -0.3333333
   0.33333337 -0.3333333  -0.3333333   1.          0.33333337  0.33333337
  -0.3333333  -1.          0.33333337  1.         -1.          0.33333337
  -0.3333333  -1.         -0.3333333   0.33333337  1.         -0.3333333
   1.         -0.3333333   0.33333337  0.33333337 -0.3333333  -1.
   1.         -1.          1.          0.33333337 -0.3333333   1.
  -0.3333333  -0.3333333   0.33333337 -1.         -1.          1.
  -1.         -0.3333333  -0.3333333  -1.         -1.         -0.3333333
   0.33333337  0.33333337  0.33333337  0.33333337  0.33333337  1.
   0.33333337  0.33333337 -1.          0.33333337]
 [ 1.          1.          1.         -0.3333333  -0.3333333   0.33333337
  -0.3333333   0.33333337  0.33333337 -1.         -0.3333333  -1.
  -0.3333333  -0.3333333  -1.         -0.3333333  -1.         -0.3333333
   1.         -1.         -1.          0.33333337 -0.3333333  -0.3333333
   0.33333337  0.33333337  0.33333337 -0.3333333  -1.         -0.3333333
   1.         -0.3333333  -0.3333333   1.          0.33333337 -0.3333333
  -0.3333333  -1.         -1.         -0.3333333  -0.3333333   0.33333337
  -0.3333333   0.33333337  1.          1.         -1.         -0.3333333
   1.          1.          1.          0.33333337 -0.3333333   1.
   0.33333337  0.33333337  0.33333337 -1.          1.          1.
   0.33333337 -0.3333333   0.33333337  1.        ]
 [ 1.         -0.3333333   1.         -0.3333333  -1.          1.
  -1.         -0.3333333  -0.3333333   0.33333337  1.         -0.3333333
   0.33333337  0.33333337 -0.3333333  -1.         -1.          1.
  -0.3333333   1.          0.33333337  1.         -0.3333333  -0.3333333
  -1.         -1.         -0.3333333   0.33333337 -0.3333333   0.33333337
  -1.          0.33333337 -0.3333333  -0.3333333  -1.          1.
  -1.         -1.         -0.3333333  -0.3333333  -0.3333333   0.33333337
  -0.3333333  -0.3333333   1.         -0.3333333  -1.          0.33333337
   1.         -0.3333333   1.          0.33333337 -1.          0.33333337
  -0.3333333   0.33333337 -0.3333333  -1.         -0.3333333   0.33333337
   1.         -1.         -0.3333333   0.33333337]]
<NDArray 4x64 @cpu(0)>
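For comparison, here is a minimal sketch of the non-mutating behavior, in plain Python so it runs without MXNet. It assumes a DoReFa-style weight quantizer (tanh, normalize by the largest |tanh(w)| over the whole tensor, round to 2^k - 1 steps, map back to [-1, 1]). That scheme is consistent with the 2-bit levels {-1, -1/3, 1/3, 1} in the output above, but it is not confirmed to be BMXNet's exact formula. The helper names `quantize_weights` and `qfc_forward` are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]

def quantize_weights(W: Matrix, bits: int) -> Matrix:
    """Return a `bits`-bit quantized COPY of W; W itself is left unchanged.

    Assumed DoReFa-style scheme (not confirmed to match BMXNet's kernel):
    tanh, rescale by the largest |tanh(w)| over the whole tensor into [0, 1],
    round to 2**bits - 1 steps, then map back to [-1, 1].
    """
    steps = 2 ** bits - 1
    T = [[math.tanh(w) for w in row] for row in W]
    scale = 2 * max(abs(t) for row in T for t in row)  # assumes W is not all zeros
    return [[2 * math.floor(steps * (t / scale + 0.5) + 0.5) / steps - 1 for t in row]
            for row in T]

def qfc_forward(x: List[float], W: Matrix, bits: int) -> List[float]:
    """y = Q(W) x, computed on a quantized copy so the caller's W is not mutated."""
    Wq = quantize_weights(W, bits)
    return [sum(w * xi for w, xi in zip(row, x)) for row in Wq]

W = [[-0.21, -0.76, 0.54, -0.45],
     [1.20, -0.46, 0.18, -1.22]]
before = [row[:] for row in W]
y = qfc_forward([1.0, 1.0, 1.0, 1.0], W, bits=2)
assert W == before  # the input weights are untouched by the forward pass
print(y)
```

Because the quantized values live only in the local copy, the same `W` could be shared with an ordinary full-precision layer without surprises.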
analog-cbarber commented 6 years ago

This does not happen if both act_bit and weight_bit are one.

yanghaojin commented 6 years ago

It will always quantize the weights and the input. We do binarization differently from multi-bit quantization. Quantization with bit widths from 2 to 31 bits is available mainly for research purposes. It gives no speed or memory gain (rather the opposite, since there are extra conversion steps), because the quantized values are still stored in full-precision float variables. The 1-bit case is different; refer here.

analog-cbarber commented 6 years ago

The issue is not whether the quantization is necessary but the fact that the input weight tensor is mutated when not in training mode. You would usually not expect your input weight tensor to be modified when not doing training. I don't know if this was done intentionally as an optimization or was inadvertent, but it should either be fixed or documented.

If someone were to share weights between a Q* layer and a regular layer, they would be surprised to find that the Q* layer modifies the weights just by running forward.

analog-cbarber commented 6 years ago

One concern about this behavior is that it might hamper learning when using a weight_bit of 2 or more. I haven't tried this scenario yet, but it seems like it would be a problem:

If this really is the behavior, then I would consider this a bug because it would amount to implicitly quantizing the gradient at the same bit width as the weights.
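A toy illustration of this concern, assuming a hypothetical fixed-grid 2-bit quantizer on [-1, 1] (levels -1, -1/3, 1/3, 1) and a constant gradient; neither is taken from BMXNet. When the forward pass writes the quantized value back, each small step is rounded away on the next call. When the full-precision weight is kept, the steps accumulate and eventually move the weight to the next level.

```python
import math

STEPS = 3  # 2-bit: 2**2 - 1 intervals -> levels {-1, -1/3, 1/3, 1}

def quantize(w: float) -> float:
    """Hypothetical fixed-grid 2-bit quantizer on [-1, 1] (illustration only)."""
    w = max(-1.0, min(1.0, w))
    return 2 * math.floor(STEPS * (w + 1) / 2 + 0.5) / STEPS - 1

def train(steps: int, overwrite: bool, lr: float = 0.05, grad: float = 1.0) -> float:
    """Run `steps` SGD updates with a constant gradient; return the final quantized weight."""
    w = 1 / 3
    for _ in range(steps):
        wq = quantize(w)        # forward pass would use wq (loss omitted; gradient is fixed)
        if overwrite:
            w = wq              # forward writes the quantized value back (the reported behavior)
        w -= lr * grad          # the gradient w.r.t. wq is applied straight through to w
    return quantize(w)

stuck = train(10, overwrite=True)     # every step is rounded back to 1/3
moved = train(10, overwrite=False)    # steps accumulate in the latent weight
print(stuck, moved)
```

With the write-back, the weight never leaves 1/3; with a latent full-precision weight, ten steps of 0.05 carry it across zero to the -1/3 level.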

yanghaojin commented 6 years ago

Hi, there is a small issue with model saving, so we always apply quantization in the forward pass to work around it. You are right: changing the weights in the forward pass when not training is not correct behavior. I will try to fix this issue.

analog-cbarber commented 6 years ago

Once again, the issue is not whether you need to quantize weights in forward mode. Unless the weights have already been binarized/quantized and you are not training, you clearly do need to quantize them. The issue is whether you should blow away the input weights when you do so, rather than simply quantizing a copy of them. If you overwrite the input weights, you are effectively quantizing the weight updates themselves, which will greatly hamper training; you generally want to apply weight updates to the unquantized weights (or at least to weights with a higher bit width), so you don't want to overwrite them. A standard way to implement this is to include the forward quantization in the computation graph explicitly but pass the gradients unchanged through the quantization step:

Forward:  W --> Quantize --> QFC
Backward: W <-- Identity <-- QFC

It is convenient to build the quantization of the weights into the QFC/QConv operators as you do, but it should retain the above semantics.
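A minimal sketch of those semantics in plain Python (no MXNet), using a hypothetical `QFCSketch` class and a 1-bit sign quantizer as the example Q; any elementwise quantizer could be passed in instead. The forward pass quantizes a copy, the backward pass hands back the gradient with respect to the quantized weights unchanged (the straight-through estimator), and the update is applied to the latent full-precision weights.

```python
from typing import Callable, List, Tuple

Matrix = List[List[float]]
Vector = List[float]

def binarize(w: float) -> float:
    """Example 1-bit quantizer: sign, with sign(0) taken as +1 (an assumption)."""
    return 1.0 if w >= 0 else -1.0

class QFCSketch:
    """Hypothetical quantized fully connected layer with straight-through gradients.

    Forward:  W --> Quantize --> FC
    Backward: W <-- Identity <-- FC
    """
    def __init__(self, W: Matrix, quantize: Callable[[float], float] = binarize):
        self.W = W                      # latent full-precision weights, never overwritten in forward
        self.quantize = quantize
        self._cache: Tuple[Vector, Matrix] = ([], [])

    def forward(self, x: Vector) -> Vector:
        Wq = [[self.quantize(w) for w in row] for row in self.W]  # quantized copy
        self._cache = (x, Wq)
        return [sum(w * xi for w, xi in zip(row, x)) for row in Wq]

    def backward(self, dy: Vector) -> Tuple[Matrix, Vector]:
        x, Wq = self._cache
        dWq = [[g * xi for xi in x] for g in dy]                   # dL/dWq = dy x^T
        dx = [sum(Wq[i][j] * dy[i] for i in range(len(dy))) for j in range(len(x))]
        return dWq, dx                  # identity step: dL/dW is taken to be dL/dWq

    def sgd_step(self, dW: Matrix, lr: float) -> None:
        """Apply the update to the latent weights, not to the quantized copy."""
        self.W = [[w - lr * g for w, g in zip(row, grow)] for row, grow in zip(self.W, dW)]

W0 = [[0.3, -0.2], [-0.05, 0.8]]
layer = QFCSketch(W0)
y = layer.forward([1.0, 2.0])
dW, dx = layer.backward([1.0, -1.0])
layer.sgd_step(dW, lr=0.1)
y_next = layer.forward([1.0, 2.0])     # W[1][0] went from -0.05 to 0.05, so its sign flipped
```

Because the update lands on the latent weight, a small value such as -0.05 can cross zero and flip the binarized weight. Had the forward pass overwritten W with -1, the same step would only have moved it to -0.9, and the sign would not have changed.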

Now it may turn out that the standard MXNet training loops will already copy the weights, so that this doesn’t matter, but it is still surprising to find that the QFullyConnected function modifies its inputs.

I think this issue is strictly with the behavior of the Q* operator itself and not with MXNet, so I don't see why any source code other than your own would need to change, and therefore why there would be any merge issues.

Quoted email reply from Haojin Yang, Monday, January 22, 2018 (Re: [hpi-xnor/BMXNet] QFullyConnected mutates weights in forward mode (#30)):

So here is some more description of the issue.

· First of all, we would like to avoid modifying MXNet source files, so that we can keep the merge process as easy as possible.

· Why do we put the weights_quantization in the forward pass? Because doing it in the backward pass, after the weight update (performed by the MXNet optimizer), would mean modifying the optimizer functions, which we want to avoid. If there is a better solution, we are more than happy to improve this.

· But why do we have to quantize the weights in prediction mode? Because the last weight update changes the quantized/binarized weights before the model is saved, so when we load a model with QConv or QFC layers, those weights are not exactly in the correct state (for 1-bit, they are mostly values like 0.99998, -1.0001, ...), and we have to quantize them once more (to -1 and 1). However, if you use a binarized model produced by model_converter, there is no such problem, since the weights have already been converted to 0 and 1 and saved in unsigned int variables. I know changing weights in non-training mode is not beautiful; we are working on this now. If anybody has a better idea, please share it with us.

· Regarding training with quantized weights, the situation described above is true: many updates to some weights may accomplish nothing. This is one reason why binary (quantized) networks are hard to train and converge more slowly. But do we really want to update most of the quantized weights in every backward pass? Take 2 bits as an example: if we modified the gradient so that the weights always moved within the quantized value interval (e.g. 1. -> .3333 -> -.3333 -> -1. -> ...), we can imagine that the weights might oscillate a lot, which could make training harder.

· We are happy to have further discussion and suggestions!
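A sketch of what converting binarized weights "into 0 and 1 and saving them in an unsigned int" can look like, assuming a simple layout in which bit i of the word is 1 for a +1 weight and 0 for a -1 weight; the actual model_converter format may differ, and `pack_bits` and `binary_dot` are hypothetical names. Binarizing by sign first also snaps drifted values such as 0.99998 and -1.0001 back to +1 and -1. The XNOR/popcount dot product is included only to check that the packed form preserves the ±1 arithmetic. Python integers are unbounded; real code would use fixed-width words such as uint32 or uint64.

```python
from typing import List

def pack_bits(weights: List[float]) -> int:
    """Binarize by sign and pack into an integer, one bit per weight.

    Assumed layout (hypothetical): bit i is 1 when weights[i] >= 0 (i.e. +1)
    and 0 when it is negative (i.e. -1).
    """
    word = 0
    for i, w in enumerate(weights):
        if w >= 0:
            word |= 1 << i
    return word

def binary_dot(a: int, b: int, n: int) -> int:
    """Dot product of two packed {-1, +1} vectors of length n via XNOR and popcount."""
    mask = (1 << n) - 1
    matches = bin(~(a ^ b) & mask).count("1")  # positions where the signs agree
    return 2 * matches - n                       # agreements add +1, disagreements add -1

w = [0.99998, -1.0001, 1.00003, -0.99997]        # drifted 1-bit weights after a late update
x = [1.0, 1.0, -1.0, 1.0]
packed_w, packed_x = pack_bits(w), pack_bits(x)
float_dot = sum((1.0 if v >= 0 else -1.0) * xi for v, xi in zip(w, x))
assert binary_dot(packed_w, packed_x, len(w)) == float_dot
```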


yanghaojin commented 6 years ago

We found the issue mentioned above, which is why I deleted my last post. I will fix it as soon as possible.

analog-cbarber commented 6 years ago

Note that the multi-bit weight quantizer not being idempotent (issue #35) makes this behavior even more problematic, because it could cause the outputs to change on repeated calls with the same input and weight tensors!
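A hypothetical illustration of why the combination matters. The quantizer below is made up for illustration and is not BMXNet's: a mid-rise quantizer that rescales by max|w|, which shrinks the largest magnitude on every pass and is therefore not idempotent. Applied in place, two identical forward calls give different outputs; applied to a copy, they agree.

```python
import math
from typing import List

def midrise_quantize(w: List[float], bits: int = 2) -> List[float]:
    """Hypothetical non-idempotent quantizer (illustration only, not BMXNet's).

    Maps w / max|w| onto the 2**bits mid-rise levels (+-0.25, +-0.75 for 2 bits)
    and rescales by max|w|, so the largest magnitude shrinks on every pass.
    """
    n = 2 ** bits
    m = max(abs(v) for v in w)  # assumes w is not all zeros
    out = []
    for v in w:
        idx = min(math.floor((v / m + 1) * n / 2), n - 1)  # bin index in [0, n - 1]
        out.append(m * ((idx + 0.5) * 2 / n - 1))
    return out

def forward_in_place(x: List[float], w: List[float]) -> float:
    w[:] = midrise_quantize(w)          # mutates the caller's weights, as reported
    return sum(a * b for a, b in zip(x, w))

def forward_copy(x: List[float], w: List[float]) -> float:
    wq = midrise_quantize(w)            # quantized copy; w is left unchanged
    return sum(a * b for a, b in zip(x, wq))

x = [1.0, 1.0]
w = [1.0, 0.4]
first, second = forward_in_place(x, w), forward_in_place(x, w)   # differ: w shrinks each call
w2 = [1.0, 0.4]
a, b = forward_copy(x, w2), forward_copy(x, w2)                  # agree; w2 unchanged
```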

yanghaojin commented 5 years ago

Please check out our new version, BMXNet v2: https://github.com/hpi-xnor/BMXNet-v2