Closed frreiss closed 5 years ago
It turns out that you can fold a batch normalization into a depthwise convolution that comes before it. You just need to be clever about indexing into the weights of the convolution op. The third PR implements that rewrite in addition to the one described at the top of this issue.
With PRs #18, #24, and #26 we can now fold the batch norm operations in MobileNet and MobileNetV2 in two ways.
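The indexing trick behind folding a per-channel batch-norm scale into the preceding depthwise convolution can be sketched as follows. This is a minimal NumPy illustration, not the PR's code; the helper and variable names are mine. The key fact is that output channel `k * channel_multiplier + q` depends only on filter slice `filter[:, :, k, q]`, so each slice is scaled by exactly one batch-norm factor.

```python
import numpy as np

def depthwise_conv2d(x, f):
    """Minimal valid-padding, stride-1 depthwise conv.
    x: [h, w, c], f: [fh, fw, c, channel_multiplier]."""
    h, w, c = x.shape
    fh, fw, _, cm = f.shape
    out = np.zeros((h - fh + 1, w - fw + 1, c * cm))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            patch = x[i:i + fh, j:j + fw, :]            # [fh, fw, c]
            # Sum over the filter window; keep (channel, multiplier) pairs,
            # flattened so output channel index is k * cm + q, as in TF.
            out[i, j, :] = np.einsum('abk,abkq->kq', patch, f).reshape(-1)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6, 3))
f = rng.standard_normal((3, 3, 3, 2))       # channel_multiplier = 2
m = rng.standard_normal(6) ** 2 + 0.1       # per-output-channel BN scale

# Unfolded: depthwise conv followed by the batch-norm multiply.
unfolded = depthwise_conv2d(x, f) * m

# Folded: scale filter[:, :, k, q] by m[k * cm + q] before convolving.
folded_filter = f * m.reshape(3, 2)[np.newaxis, np.newaxis, :, :]
folded = depthwise_conv2d(x, folded_filter)

assert np.allclose(unfolded, folded)
```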
The rewrites `fold_batch_norms` and `fold_old_batch_norms` in the TensorFlow Graph Transform Tool do not work when the batch normalization layer comes immediately after a `DepthwiseConv2D` layer. As a result, these rewrites do not work with MobileNetV2 or any model that embeds MobileNetV2. This seems like a rather significant oversight, given that MobileNet and MobileNet-derived models are the most common use case for these kinds of graph-simplifying rewrites. This problem affects several models in the Model Asset Exchange.

Folding batch normalization into depthwise convolution is a bit tricky because each coefficient in a depthwise convolution participates in every output. In particular, the formula for a depthwise convolution is `output[b, i, j, k * channel_multiplier + q] = sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * filter[di, dj, k, q]`
(see https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d). This reuse of filter elements means that we can't fold a batch normalization that happens after the depthwise convolution into the convolution. Batch normalization multiplies each channel by a different amount (1/stdev of the channel), and there's no single place in the filter where those scale factors could be applied.
Instead, we need to fuse the batch normalization into a `Conv2D` or `DepthwiseConv2D` that happens after the normalization. This fusion is a bit trickier, because batch normalization breaks down into a multiply followed by an add, and there is typically a ReLU between the normalization and the next convolution. For example, the basic building block of MobileNet v1 is the chain `DepthwiseConv2D → BatchNorm → ReLU → Conv2D (1x1) → BatchNorm → ReLU`.
The second BN in this chain is covered by the existing rewrite. We need to fold the first BN into the 1x1 `Conv2D` that happens after it. That chunk, `BatchNorm → ReLU → Conv2D`, breaks down to `Mul(m) → Add(b) → ReLU → Conv2D`, where `m` and `b` are the per-channel scale and offset of the normalization. So we need to pull the multiply into the `Conv2D`. Another way to write the above sequence of ops is `Add(b/m) → ReLU → Mul(m) → Conv2D`.
As it happens, `m` is always >= 0, since it's equal to 1/stdev, so moving the multiply past the ReLU in this way is valid. So, switching back to operator notation, we just need to turn the trailing `Mul(m) → Conv2D` into a single `Conv2D` and rewrite the `Add(b)` to `Add(b/m)`.
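That commutation relies only on `m` being nonnegative: `ReLU(m*x + b) = m * ReLU(x + b/m)` when `m >= 0`. A quick numerical check (illustrative code, not from the issue):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal(1000)
m = rng.standard_normal(1000) ** 2 + 1e-3   # strictly positive, like 1/stdev
b = rng.standard_normal(1000)

relu = lambda t: np.maximum(t, 0.0)

# Because m > 0, the per-channel multiply commutes past the ReLU
# once the additive term is rescaled from b to b/m.
assert np.allclose(relu(m * x + b), m * relu(x + b / m))
```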
The equation for Conv2D is `output[b, i, j, k] = sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * filter[di, dj, q, k]` (see https://www.tensorflow.org/api_docs/python/tf/nn/conv2d). Collapsing down the striding parts to `f_i` and `f_j`, we have `output[b, i, j, k] = sum_{di, dj, q} input[b, f_i, f_j, q] * filter[di, dj, q, k]`. So the equation for a Conv2D on top of a multiplication by `m` is `output[b, i, j, k] = sum_{di, dj, q} input[b, f_i, f_j, q] * (m[q] * filter[di, dj, q, k])`.
So we just need to multiply every filter element in `filter[_, _, q, _]` by `m[q]` for each value of `q`
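This per-input-channel filter scaling can be checked numerically. The sketch below uses my own helper names and a toy NumPy `Conv2D`, not the tool's implementation:

```python
import numpy as np

def conv2d(x, f):
    """Minimal valid-padding, stride-1 Conv2D.
    x: [h, w, c_in], f: [fh, fw, c_in, c_out]."""
    h, w, c = x.shape
    fh, fw, _, c_out = f.shape
    out = np.zeros((h - fh + 1, w - fw + 1, c_out))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Sum over window positions and input channels.
            out[i, j, :] = np.einsum('abq,abqk->k', x[i:i + fh, j:j + fw, :], f)
    return out

rng = np.random.default_rng(2)
x = rng.standard_normal((5, 5, 4))
f = rng.standard_normal((3, 3, 4, 8))
m = rng.standard_normal(4) ** 2 + 0.1       # per-input-channel scale

# Unfolded: multiply each input channel by m[q], then convolve.
unfolded = conv2d(x * m, f)

# Folded: scale filter[:, :, q, :] by m[q] and convolve the raw input.
folded = conv2d(x, f * m[np.newaxis, np.newaxis, :, np.newaxis])

assert np.allclose(unfolded, folded)
```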
. The same principle applies to `DepthwiseConv2D`.

Description of work to address this problem:
- A rewrite in `rewrite.py` that replicates the current functionality of the `fold_batch_norms` and `fold_old_batch_norms` rewrites in the Graph Transform Tool.