google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Gradient leakage through masked convolutions #12159

Open phlippe opened 1 year ago

phlippe commented 1 year ago

Description

Short summary: When applying a mask to a convolution kernel, the gradients are unexpectedly non-zero for masked input elements. A minimal example is available in Colab.

The bug/unexpected behavior arises when we want to mask certain elements in a convolutional filter. For example, in autoregressive image modeling, a pixel must not look at 'future' pixels, so that the model can be trained efficiently via teacher forcing. However, when investigating the gradients through a masked convolution, it turns out that the gradient for masked input elements is non-zero, despite the kernel value being zero for these pixels.

Example

Consider an image of size 5x5 on which we apply a 3x3 kernel. We mask out the last row of the kernel, hence only taking into account the pixels above and on the same row of a reference pixel. We then apply this masked convolution to the input image, and determine the gradients of the input image through this operation. For clarity, we use the sum of the center pixel features as 'loss', which corresponds to visualizing the receptive field of the center pixel.

import jax
import jax.numpy as jnp
from jax import lax
from jax import random

# Setting up input, kernel, mask
feat_dim = 128
rng = random.PRNGKey(42)
rng, inp_rng, kernel_rng = random.split(rng, 3)
inp = random.normal(inp_rng, (1, 5, 5, feat_dim))
kernel = 1/jnp.sqrt(9*feat_dim) * random.normal(kernel_rng, (3, 3, feat_dim, feat_dim))
# Masking last row of 3x3 filter: [[1, 1, 1], [1, 1, 1], [0, 0, 0]]
mask = jnp.concatenate([jnp.ones((2, 3, feat_dim, feat_dim)),
                        jnp.zeros((1, 3, feat_dim, feat_dim))], axis=0)

# Applying convolution with mask applied to kernel
def apply_masked_conv(input_tensor, kernel_tensor):
  kernel_tensor *= mask
  y = lax.conv_general_dilated(
          input_tensor,
          kernel_tensor,
          (1, 1),
          "SAME",
          dimension_numbers=('NHWC', 'HWOI', 'NHWC')
      )
  return y

# Example gradient function with respect to the center pixel at position 2,2
grad_fn = jax.grad(lambda input_tensor: apply_masked_conv(input_tensor, kernel)[:,2,2,:].sum())
grads = grad_fn(inp)

# Printing absolute gradient per pixel, averaged over input channels
print(jnp.abs(grads).mean(axis=-1))

Because of the masking, one would expect that only the pixels inp[:,1:3,1:4] have non-zero gradients. However, the gradients are as follows:

[[[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 2.4213897e-01 2.4514729e-01 2.7408689e-01 0.0000000e+00]
  [0.0000000e+00 2.9646739e-01 2.5383139e-01 2.4915963e-01 0.0000000e+00]
  [0.0000000e+00 6.6611392e-08 1.0057556e-07 1.1386874e-07 0.0000000e+00]
  [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]]

The elements [6.6611392e-08 1.0057556e-07 1.1386874e-07] are the gradients for the pixels that are masked out by the kernel, and are non-zero despite the kernel value being zero for them. Note that this occurs on a GPU, but the elements are zeros if the code is run on CPU, or a very small channel size is used (e.g. 4). Additionally, the gradients of the kernels are correctly zero for the masked elements.
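
To put the size of the leaked values in context, here is a small NumPy sketch that reuses the GPU numbers printed above. The helper names (`expected`, `leaked`, `legit`) are only illustrative, and since the printed values are channel-averaged absolute gradients, the comparison with float32 machine epsilon is suggestive at best; it does not identify the cause.

```python
import numpy as np

# Per-pixel mean absolute gradients, copied from the GPU output above.
grads_gpu = np.array([
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 2.4213897e-01, 2.4514729e-01, 2.7408689e-01, 0.0],
    [0.0, 2.9646739e-01, 2.5383139e-01, 2.4915963e-01, 0.0],
    [0.0, 6.6611392e-08, 1.0057556e-07, 1.1386874e-07, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0],
], dtype=np.float32)

# Expected receptive field of the center pixel: rows 1-2, columns 1-3.
expected = np.zeros(grads_gpu.shape, dtype=bool)
expected[1:3, 1:4] = True

leaked = grads_gpu[~expected]  # entries that should all be exactly zero
legit = grads_gpu[expected]    # legitimately non-zero gradients

max_leak = float(leaked.max())
ratio = max_leak / float(legit.mean())
eps32 = float(np.finfo(np.float32).eps)

print(f"largest leaked gradient:            {max_leak:.3e}")
print(f"relative to mean unmasked gradient: {ratio:.3e}")
print(f"float32 machine epsilon:            {eps32:.3e}")
```

The largest leaked entry is slightly below float32 machine epsilon, and more than six orders of magnitude below the mean of the unmasked entries.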

Relevance

Masked convolutions are used in autoregressive convolutional models such as PixelCNN. When implemented with masked convolutions, one sees that the receptive field of a pixel spans across the whole image. A full implementation example with the unexpected behavior can be found here. This leads to gradients through masked features when stacking multiple masked layers, and introduces a gradient bias to the optimization of the kernels through masked inputs. A simple example, where the optimization goes wrong because of this, can be found here. In practice, the leakage is not directly noticeable, since the leaked gradients are usually a few orders of magnitude smaller than those of the non-masked elements.

A work-around, which is also more efficient, is to use smaller kernels in the first place (e.g. [2x3] instead of masked [3x3]). But considering that masked convolutions are supported in, e.g., Flax, I think it would be good to point out this gradient leakage. Please let me know in case this bug/behavior has already been discussed in detail somewhere in the JAX documentation; a search through it and the current GitHub issues didn't show any results.
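
A minimal sketch of the smaller-kernel work-around, under some assumptions: the function names `masked_conv3x3` and `small_conv2x3` are made up for illustration, the channel count is reduced to keep it quick, and the explicit padding `((1, 0), (1, 1))` is my choice to reproduce the "SAME" alignment of the masked 3x3 kernel. It has not been checked on a GPU here, so it shows the equivalence of the two formulations rather than proving that the leakage disappears.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

feat_dim = 8  # small channel count, only to keep the sketch quick
inp_rng, kernel_rng = random.split(random.PRNGKey(0))
inp = random.normal(inp_rng, (1, 5, 5, feat_dim))
kernel = random.normal(kernel_rng, (3, 3, feat_dim, feat_dim)) / jnp.sqrt(9 * feat_dim)
DIMS = ('NHWC', 'HWOI', 'NHWC')  # same dimension numbers as in the example above

def masked_conv3x3(x, k):
    # Reference: zero the last kernel row, then convolve with SAME padding.
    mask = jnp.concatenate([jnp.ones((2, 3, 1, 1)), jnp.zeros((1, 3, 1, 1))], axis=0)
    return lax.conv_general_dilated(x, k * mask, (1, 1), "SAME", dimension_numbers=DIMS)

def small_conv2x3(x, k):
    # Work-around: keep only the two unmasked kernel rows and pad one row
    # on top (none at the bottom), so each output pixel reads the row
    # above it and its own row.
    return lax.conv_general_dilated(
        x, k[:2], (1, 1), padding=((1, 0), (1, 1)), dimension_numbers=DIMS)

out_masked = masked_conv3x3(inp, kernel)
out_small = small_conv2x3(inp, kernel)
print("max difference:", jnp.abs(out_masked - out_small).max())

# Receptive field of the center pixel for the 2x3 variant.
grads = jax.grad(lambda x: small_conv2x3(x, kernel)[:, 2, 2, :].sum())(inp)
print(jnp.abs(grads).mean(axis=-1))
```

With the 2x3 kernel, the row below the reference pixel is never part of the convolution window, so there is no masked kernel entry left for it to interact with.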

What jax/jaxlib version are you using?

jax v0.3.14, jaxlib v0.3.14

Which accelerator(s) are you using?

GPU (behavior as expected on CPU)

Additional System Info

Tested on Colab and locally (Ubuntu, Python 3.9, NVIDIA GTX1080Ti)

jakevdp commented 1 year ago

Thanks for the report, this is an interesting bug. Strangely enough it looks like jit-compiling grad_fn is sufficient to fix the issue:

grads = jax.jit(grad_fn)(inp)
print(jnp.abs(grads).mean(axis=-1))
# [[[0.         0.         0.         0.         0.        ]
#   [0.         0.24213897 0.24514729 0.2740869  0.        ]
#   [0.         0.29646742 0.2538314  0.24915966 0.        ]
#   [0.         0.         0.         0.         0.        ]
#   [0.         0.         0.         0.         0.        ]]]
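
To compare the eager and jit-compiled paths side by side, something like the following sketch could be used. It rebuilds the setup from the report (with a broadcastable mask for brevity); `max_leak` is a hypothetical helper introduced here, not part of JAX, and it reports the largest absolute gradient outside the expected receptive field inp[:, 1:3, 1:4].

```python
import jax
import jax.numpy as jnp
from jax import lax, random

feat_dim = 128
inp_rng, kernel_rng = random.split(random.PRNGKey(42))
inp = random.normal(inp_rng, (1, 5, 5, feat_dim))
kernel = random.normal(kernel_rng, (3, 3, feat_dim, feat_dim)) / jnp.sqrt(9 * feat_dim)
# Mask the last row of the 3x3 filter; broadcasts over both channel axes.
mask = jnp.concatenate([jnp.ones((2, 3, 1, 1)), jnp.zeros((1, 3, 1, 1))], axis=0)

def apply_masked_conv(x, k):
    return lax.conv_general_dilated(
        x, k * mask, (1, 1), "SAME", dimension_numbers=('NHWC', 'HWOI', 'NHWC'))

grad_fn = jax.grad(lambda x: apply_masked_conv(x, kernel)[:, 2, 2, :].sum())

def max_leak(grads):
    # Hypothetical helper: largest |gradient| over all pixels that lie
    # outside the expected receptive field of the center pixel.
    outside = jnp.ones(grads.shape[1:3], dtype=bool).at[1:3, 1:4].set(False)
    per_pixel = jnp.abs(grads[0]).max(axis=-1)
    return float(jnp.where(outside, per_pixel, 0.0).max())

leak_eager = max_leak(grad_fn(inp))
leak_jit = max_leak(jax.jit(grad_fn)(inp))
print(f"backend={jax.default_backend()} eager={leak_eager:.3e} jit={leak_jit:.3e}")
```

A value of exactly zero means no leakage for that path; anything else gives the magnitude to compare against the unmasked gradients.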

phlippe commented 1 year ago

Hi @jakevdp, thanks for looking at it. Interesting that jit-compiling fixes it for this single layer! However, the optimization through two layers still showed issues when the update step was jit-compiled, and the same holds for the PixelCNN. On further investigation, it seems that the output of the convolution itself is already affected. For example, consider an input image that has 0s in the first three rows and 1s in the last two rows, for all channels. Applying a filter with the mask above should output zero features for the center pixel, since for this single pixel we convolve a filter with mask [[1 1 1], [1 1 1], [0 0 0]] with an input patch [[0 0 0], [0 0 0], [1 1 1]]. In other words, every product has a zero factor: the unmasked kernel rows meet zero inputs, and the non-zero inputs meet the masked kernel row. However, the output is again non-zero for these pixels:

inp = jnp.concatenate([jnp.zeros((1, 3, 5, feat_dim)),
                       jnp.ones((1, 2, 5, feat_dim))], axis=1)
out = jax.jit(apply_masked_conv)(inp, kernel)
print(jnp.abs(out).mean(axis=-1))

Output on a GPU (Colab):

[[[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
  [1.4640318e-07 1.2042801e-07 1.2042801e-07 1.2042801e-07 1.2423152e-07]
  [3.6361670e-01 4.2430118e-01 4.2430118e-01 4.2430118e-01 3.6237997e-01]
  [5.5979115e-01 6.9566023e-01 6.9566023e-01 6.9566023e-01 5.3750622e-01]]]

On a CPU, the output is, as expected, zero for the center row:

[[[0.         0.         0.         0.         0.        ]
  [0.         0.         0.         0.         0.        ]
  [0.         0.         0.         0.         0.        ]
  [0.3636167  0.42430118 0.42430118 0.42430118 0.36237997]
  [0.55979115 0.69566035 0.69566035 0.69566035 0.5375063 ]]]

Interestingly, the upper two rows are all zeros even on a GPU, which suggests that the behavior only occurs when the input is non-zero.
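
The CPU/GPU comparison above can be run in a single process, roughly as sketched below. This assumes that committing the inputs to a device with `jax.device_put` makes the jitted computation run on that device; `center_row_max` is a hypothetical helper name, and the mask is written in broadcastable form for brevity.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

feat_dim = 128
kernel = random.normal(random.PRNGKey(42), (3, 3, feat_dim, feat_dim)) / jnp.sqrt(9 * feat_dim)
mask = jnp.concatenate([jnp.ones((2, 3, 1, 1)), jnp.zeros((1, 3, 1, 1))], axis=0)
# Zeros in the first three rows, ones in the last two, as described above.
inp = jnp.concatenate([jnp.zeros((1, 3, 5, feat_dim)),
                       jnp.ones((1, 2, 5, feat_dim))], axis=1)

@jax.jit
def apply_masked_conv(x, k):
    return lax.conv_general_dilated(
        x, k * mask, (1, 1), "SAME", dimension_numbers=('NHWC', 'HWOI', 'NHWC'))

def center_row_max(device):
    # Hypothetical helper: place the inputs on `device`, run the masked
    # convolution there, and return the largest |output| in the center row.
    x = jax.device_put(inp, device)
    k = jax.device_put(kernel, device)
    return float(jnp.abs(apply_masked_conv(x, k)[0, 2]).max())

cpu_max = center_row_max(jax.devices("cpu")[0])
default_max = center_row_max(jax.devices()[0])
print(f"center-row max |output|: cpu={cpu_max:.3e}, "
      f"{jax.default_backend()}={default_max:.3e}")
```

On a machine without an accelerator both numbers come from the CPU backend, so the comparison is only informative when `jax.default_backend()` reports a GPU or TPU.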

phlippe commented 1 year ago

In a related discussion, it was suggested that the reason for this difference might be the convolution kernel chosen by XLA.