CODAIT / graph_def_editor

GraphDef Editor: A port of the TensorFlow contrib.graph_editor package that operates over serialized graphs
Apache License 2.0

Replicate and extend TF fold_batch_norms rewrites #15

Closed frreiss closed 5 years ago

frreiss commented 5 years ago

The rewrites fold_batch_norms and fold_old_batch_norms in the TensorFlow Graph Transform Tool do not work when the batch normalization layer comes immediately after a DepthwiseConv2D layer. As a result, these rewrites do not work with MobileNetV2 or any model that embeds MobileNetV2. This seems like a rather significant oversight, given that MobileNet and MobileNet-derived models are the most common use case for these kinds of graph-simplifying rewrites. This problem affects several models in the Model Asset Exchange.

Folding batch normalization into depthwise convolution is a bit tricky because each coefficient in a depthwise convolution participates in every output. In particular, the formula for a depthwise convolution is:

output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
     filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
                                     strides[2] * j + rate[1] * dj, k]

(see https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d). This reuse of filter elements means that we can't fold a batch normalization that happens after the depthwise convolution into the convolution. Batch normalization multiplies each channel by a different amount (1/stdev of the channel), and there's no single place in the filters where those per-channel factors could be folded in.
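
For reference, a short shape check of that formula (sizes here are illustrative, not from any particular model): each input channel k produces channel_multiplier output channels, laid out as k * channel_multiplier + q.

import tensorflow as tf

# Illustrative shapes for tf.nn.depthwise_conv2d: the filter has shape
# [filter_height, filter_width, in_channels, channel_multiplier], and the
# output has in_channels * channel_multiplier channels, indexed as
# k * channel_multiplier + q in the formula above.
in_channels, channel_multiplier = 3, 2
x = tf.random.normal([1, 5, 5, in_channels])
f = tf.random.normal([3, 3, in_channels, channel_multiplier])

y = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="VALID")
print(y.shape)  # (1, 3, 3, 6)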

Instead, we need to fuse batch normalization into a Conv2D or DepthwiseConv2D that happens after the normalization. This fusion is a bit trickier, because batch normalization breaks down into a multiply followed by an add, and there is typically a ReLU before the next convolution. For example, the basic building block of MobileNet v1 is

3x3 DepthwiseConv2D -> BN -> ReLU -> 1x1 Conv2D -> BN -> ReLU

The second BN in this chain is covered by the existing rewrite. We need to fold the first BN into the 1x1 Conv2D that happens after it. That chunk

BN -> ReLU -> 1x1 Conv2D

breaks down to

Multiply -> Add -> ReLU -> 1x1 Conv2D
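
The Multiply and Add here are the two halves of a frozen batch normalization. As a minimal sketch of that decomposition (assuming BN without a learned scale, so the multiplier is the 1/stdev factor used below; names are illustrative):

import numpy as np

# A frozen batch normalization is a per-channel multiply followed by a
# per-channel add: BN(x) == m * x + b.
def bn_as_multiply_add(beta, mean, variance, epsilon=0.001):
    m = 1.0 / np.sqrt(variance + epsilon)  # per-channel multiplier (always >= 0)
    b = beta - mean * m                    # per-channel offset
    return m, b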

So we need to pull the multiply into the Conv2D. Another way to write the above sequence of ops is:

Conv2D(ReLU(mx + b))
    == Conv2D(ReLU(m(x + b/m)))
    == Conv2D(m * ReLU(x + b/m)) iff m >= 0 (cell-wise)

As it happens, m is always >= 0, since it's equal to 1/stdev. So, switching back to operator notation, we just need to turn

Multiply -> Conv2D

into a single Conv2D and rewrite the Add(b) to Add(b/m).
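
A quick numeric check of that identity, with illustrative values (m is kept strictly positive, as it is when it equals 1/stdev):

import numpy as np

# Check: ReLU(m*x + b) == m * ReLU(x + b/m) when m >= 0 (elementwise).
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))          # arbitrary activations
m = rng.uniform(0.1, 2.0, size=8)    # per-channel multiplier, strictly positive
b = rng.normal(size=8)               # per-channel offset

relu = lambda t: np.maximum(t, 0.0)

print(np.allclose(relu(m * x + b), m * relu(x + b / m)))  # True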

The equation for Conv2D is:

output[b, i, j, k] =
    sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
                    filter[di, dj, q, k]

(see https://www.tensorflow.org/api_docs/python/tf/nn/conv2d). Collapsing down the striding parts to f_i and f_j, we have:

output[b, i, j, k] =
    sum_{di, dj, q} input[b, f_i(i, di), f_j(j, dj), q] * filter[di, dj, q, k]

So the equation for a Conv2D on top of a multiplication by m is:

output[b, i, j, k] =
    sum_{di, dj, q} (input[b, f_i(i, di), f_j(j, dj), q] * m[q]) * filter[di, dj, q, k]
    = sum_{di, dj, q} input[b, f_i(i, di), f_j(j, dj), q] * (m[q] * filter[di, dj, q, k])

So we just need to multiply every filter element in filter[_, _, q, _] by m[q] for each value of q. The same principle applies to DepthwiseConv2D.
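
A small sketch of that filter scaling for Conv2D, with illustrative shapes (not the rewrite as implemented, just a numeric check of the algebra above):

import numpy as np
import tensorflow as tf

# A per-input-channel Multiply in front of a Conv2D can be absorbed by
# scaling filter[:, :, q, :] by m[q].
x = tf.random.normal([1, 8, 8, 4])
w = tf.random.normal([3, 3, 4, 16])
m = tf.random.uniform([4], 0.1, 2.0)   # per-channel multiplier (m >= 0)

before = tf.nn.conv2d(x * m, w, strides=1, padding="SAME")
after = tf.nn.conv2d(x, w * tf.reshape(m, [1, 1, 4, 1]), strides=1, padding="SAME")
print(np.allclose(before.numpy(), after.numpy(), atol=1e-4))  # True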

Description of work to address this problem:

frreiss commented 5 years ago

It turns out that you can fold a batch normalization into a depthwise convolution that comes before it. You just need to be clever about indexing into the weights of the convolution op. A third PR implements that rewrite in addition to the one described at the top of this issue.
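
A sketch of that indexing, with assumed shapes: the per-output-channel BN multiplier reshapes to [in_channels, channel_multiplier] and scales the depthwise filter directly, because filter[:, :, k, q] only ever contributes to output channel k * channel_multiplier + q.

import numpy as np
import tensorflow as tf

# Fold a per-output-channel scale s into the depthwise filter.
in_ch, mult = 3, 2
x = tf.random.normal([1, 8, 8, in_ch])
f = tf.random.normal([3, 3, in_ch, mult])
s = tf.random.uniform([in_ch * mult], 0.1, 2.0)   # BN multiplier per output channel

before = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME") * s
after = tf.nn.depthwise_conv2d(x, f * tf.reshape(s, [1, 1, in_ch, mult]),
                               strides=[1, 1, 1, 1], padding="SAME")
print(np.allclose(before.numpy(), after.numpy(), atol=1e-4))  # True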

frreiss commented 5 years ago

With PRs #18, #24, and #26, we can now fold the batch norm operations in MobileNet and MobileNetV2 in two ways.