google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.27k stars, 225 forks

value_and_grad(kernel_fn) not equal to kernel_fn with standard parameterization #123

Closed: PythonNut closed this issue 2 years ago

PythonNut commented 3 years ago

I am confused by the behavior of the following snippet of code (the WideResNet from the README with standard parameterization):

import jax
from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    main = stax.serial(
        stax.Relu(),
        stax.Conv(
            channels, (3, 3), strides, padding="SAME", parameterization="standard"
        ),
        stax.Relu(),
        stax.Conv(channels, (3, 3), padding="SAME", parameterization="standard"),
    )
    shortcut = (
        stax.Identity()
        if not channel_mismatch
        else stax.Conv(
            channels, (3, 3), strides, padding="SAME", parameterization="standard"
        )
    )
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut), stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
    blocks = []
    blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [WideResnetBlock(channels, (1, 1))]
    return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding="SAME", parameterization="standard"),
        WideResnetGroup(block_size, int(16 * k)),
        WideResnetGroup(block_size, int(32 * k), (2, 2)),
        WideResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.AvgPool((8, 8)),
        stax.Flatten(),
        stax.Dense(num_classes, 1.0, 0.0, parameterization="standard"),
    )

_, _, kernel_fn = WideResnet(block_size=4, k=1, num_classes=1)

def kernel_scalar(x, y):
    return kernel_fn(x, y, "ntk")[0, 0]

z = jax.numpy.zeros((1, 32, 32, 3))
print(jax.value_and_grad(kernel_scalar)(z, z)[0])
print(kernel_scalar(z, z))

My understanding is that the two printed values should be the same. However, when I run it, I get two totally different values:

34.41480472358908
64.62813414153004

Is my understanding correct? I have not yet found a simpler network that features this behavior.

Versions:

romanngg commented 3 years ago

Thanks for the repro, and good find! It's indeed a bug in our custom differentiation rule for the square root: we clip the derivative around zero, but we accidentally clipped the output as well. I've sent a change to fix it. It needs code review, so it will likely land tomorrow. In the meantime, this is what the change looks like: https://github.com/google/neural-tangents/blob/94e7498863916ee4fa44e448ae02fc9682da9f27/neural_tangents/stax.py#L4343

def _sqrt_jvp(tol, primals, tangents):
  x, = primals
  x_dot, = tangents
  safe_tol = max(tol, 1e-30)
  square_root = _sqrt(x, safe_tol)
+ square_root_out = _sqrt(x, tol)
- return square_root, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
+ return square_root_out, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
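For readers unfamiliar with custom differentiation rules, here is a minimal plain-JAX sketch of why clipping the output inside the rule produces exactly the reported symptom. It is not the library's _sqrt/_sqrt_jvp: the names make_sqrt, buggy_sqrt, fixed_sqrt and the TOL value are hypothetical, and TOL is exaggerated so the effect is visible. The relevant JAX behavior is that when a jax.custom_jvp function is differentiated, the primal value comes from the JVP rule rather than from the undecorated function. A clipped output in the rule therefore changes what value_and_grad reports, while a plain call is unaffected.

```python
import jax
import jax.numpy as jnp

# Hypothetical tolerance, exaggerated so the effect of clipping is visible.
TOL = 1e-2


def make_sqrt(clip_output):
    @jax.custom_jvp
    def sqrt_fn(x):
        # Plain (non-differentiated) calls use this definition.
        return jnp.sqrt(x)

    @sqrt_fn.defjvp
    def sqrt_fn_jvp(primals, tangents):
        x, = primals
        x_dot, = tangents
        clipped = jnp.sqrt(jnp.maximum(x, TOL))
        # Under differentiation, JAX takes the primal value from here.
        out = clipped if clip_output else jnp.sqrt(x)
        # The derivative is clipped in both versions: it is defined as 0 below
        # TOL, and the denominator never reaches 0, so no inf/NaN appears.
        tangent = jnp.where(x > TOL, x_dot / (2 * clipped), 0.0)
        return out, tangent

    return sqrt_fn


buggy_sqrt = make_sqrt(clip_output=True)   # loosely analogous to the "-" line
fixed_sqrt = make_sqrt(clip_output=False)  # loosely analogous to the "+" lines

print(buggy_sqrt(0.0), jax.value_and_grad(buggy_sqrt)(0.0)[0])  # 0.0 vs. about 0.1
print(fixed_sqrt(0.0), jax.value_and_grad(fixed_sqrt)(0.0)[0])  # 0.0 vs. 0.0
```

The design choice is to clip only what feeds the derivative, so the forward value is identical whether or not the function is being differentiated.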
PythonNut commented 3 years ago

Wow, thanks for finding the issue so quickly!

PythonNut commented 3 years ago

Hmm, after pulling 8b7917f, I see that the values match now, but the gradient is all NaNs. Is that the intended outcome?

romanngg commented 2 years ago

Thanks, I'll need to look into this. In the meantime, I suspect it only happens for zero-valued inputs and generally shouldn't be a problem otherwise. I may be wrong, though, so it's worth double-checking whether there are still NaNs or discrepancies for normal inputs like images.

PythonNut commented 2 years ago

A much smaller example reproducing the NaN issue:

import jax
from neural_tangents.stax import serial, Conv, Relu, Flatten

def f(x):
    _, _, kernel_fn = serial(Conv(1, (3, 3)), Relu(), Flatten())
    return kernel_fn(x, x, "ntk")[0, 0]

print(jax.grad(f)(jax.numpy.zeros((1, 32, 32, 3))))

I guess this no longer has anything to do with parameterization="standard" (it happens either way), so should this be a new issue?

romanngg commented 2 years ago

I think it's probably the same issue, likely related to differentiating kernel functions with Relu (and possibly other nonlinearities; I'll need to look into this) at exactly zero inputs, e.g. jax.numpy.zeros((1, 32, 32, 3)). I agree it's likely not specific to parameterization.
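A minimal plain-JAX illustration of how exactly-zero inputs can turn into NaN gradients. Kernels of ReLU networks involve square roots of norm-like quantities, as the sqrt rule in the diff above suggests. The derivative of sqrt at 0 is infinite, and the chain rule then multiplies that infinity by a zero inner derivative, giving NaN. The functions naive_norm and safe_norm are hypothetical, not neural-tangents code. safe_norm uses the common "double where" guard, which also defines the gradient at zero to be 0.

```python
import jax
import jax.numpy as jnp


def naive_norm(x):
    # d/ds sqrt(s) is infinite at s = 0; the chain rule multiplies it by
    # d(sum x^2)/dx = 2x = 0, and inf * 0 = NaN.
    return jnp.sqrt(jnp.sum(x ** 2))


def safe_norm(x):
    sq = jnp.sum(x ** 2)
    is_zero = sq == 0.0
    # "Double where": give sqrt a harmless dummy input (1.0) where sq == 0,
    # so the backward pass never evaluates sqrt's derivative at 0...
    safe_sq = jnp.where(is_zero, 1.0, sq)
    # ...then select the true value (0) for the output. The gradient at 0 is 0.
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))


zero = jnp.zeros(3)
print(jax.grad(naive_norm)(zero))  # all NaN
print(jax.grad(safe_norm)(zero))   # all zeros
print(jax.grad(safe_norm)(jnp.array([3.0, 4.0])))  # x / ||x||, about [0.6, 0.8]
```

A single where around the output is not enough on its own: the unselected branch is still differentiated, and its NaN leaks into the gradient. Hence the second where on the input.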

romanngg commented 2 years ago

Thank you for your patience here! I think you were right that there were actually two bugs here.

One was the wrong treatment of biases with b_std=None and parameterization='standard'. Precisely, in standard parameterization, having no bias (b_std=None) and having a zero-variance bias (b_std=0) are not the same.
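A minimal sketch of why the two cases differ. It computes the empirical NTK, Theta(x1, x2) = <df(x1)/dtheta, df(x2)/dtheta>, of a single scalar dense unit, written by hand in JAX. None of these function names are neural-tangents API. Under the assumption that standard parameterization means f(x) = x.w + b, with b_std only setting the initial spread of b, a bias initialized to zero is still a trainable parameter. Its gradient is always 1, so it adds 1 to the NTK. Under NTK parameterization, f(x) = (W_std / sqrt(n)) x.w + b_std * b, so b_std = 0 also zeroes the bias gradient and is indistinguishable from having no bias.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def dense_standard(params, x):
    # Standard parameterization: f(x) = x . w + b. The init stds only affect
    # how (w, b) are sampled, so b_std = 0 just means b starts at 0.
    w, b = params
    return jnp.dot(x, w) + b


def dense_standard_no_bias(w, x):
    # Same layer with no bias parameter at all (the b_std=None case).
    return jnp.dot(x, w)


def dense_ntk_param(params, x, w_std=1.0, b_std=0.0):
    # NTK parameterization: the stds appear in the forward pass, so
    # b_std = 0 also zeroes the gradient with respect to b.
    w, b = params
    return w_std / jnp.sqrt(x.shape[-1]) * jnp.dot(x, w) + b_std * b


def empirical_ntk(f, params, x1, x2):
    # Theta(x1, x2) = <df(x1)/dtheta, df(x2)/dtheta>, summed over all params.
    flat, unravel = ravel_pytree(params)
    grad_at = lambda x: jax.grad(lambda theta: f(unravel(theta), x))(flat)
    return jnp.dot(grad_at(x1), grad_at(x2))


x1 = jnp.array([1.0, 2.0])
x2 = jnp.array([3.0, -1.0])  # x1 . x2 = 1
w = jnp.array([0.5, -0.3])
b = jnp.array(0.0)  # bias initialized to exactly zero

print(empirical_ntk(dense_standard, (w, b), x1, x2))     # x1.x2 + 1 = 2
print(empirical_ntk(dense_standard_no_bias, w, x1, x2))  # x1.x2 = 1
print(empirical_ntk(dense_ntk_param, (w, b), x1, x2))    # about x1.x2 / 2 = 0.5
```

The values of w and b do not affect the result here, because the layer is linear in its parameters. Only which parameters exist, and how the stds enter the forward pass, matter.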

The other was that the derivative of the kernel with respect to x1 or x2 at zero is technically undefined for ReLU and similar activations. We now set the gradient at zero inputs to 0. This is correct for x1 = x2 = 0 and for x2 = 0, but, for example, dK(x1, x2 != 0)/dx1 at x1 = 0 is genuinely undefined (or infinite), and we still return 0. While this is technically incorrect, it matches JAX's behavior of defining the gradient of non-differentiable functions as the mean subgradient, e.g. in JAX jax.grad(jax.numpy.sign)(0.) == 0. and jax.grad(lambda x: jax.numpy.maximum(x, 0.))(0.) == 0.5, so it is arguably a reasonable value to return.
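The JAX conventions mentioned above can be checked directly. The following is a quick sketch in plain JAX (no neural-tangents needed). jax.nn.relu is included for comparison, because it has its own derivative rule that picks 0 at the kink, rather than the 0.5 obtained by differentiating jnp.maximum(x, 0.) directly.

```python
import jax
import jax.numpy as jnp

# sign is piecewise constant, so JAX reports a derivative of 0 everywhere,
# including at 0.
print(jax.grad(jnp.sign)(0.0))

# At the kink of max(x, 0), JAX splits the gradient evenly between the two
# tied arguments, i.e. the mean of the one-sided derivatives 0 and 1.
print(jax.grad(lambda x: jnp.maximum(x, 0.0))(0.0))

# jax.nn.relu uses a custom derivative rule that returns 0 at x = 0.
print(jax.grad(jax.nn.relu)(0.0))
```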

Hope this helps!

PythonNut commented 2 years ago

Thanks so much for the thorough fix! All of the gradient-related anomalies I've been seeing have gone away. I'll open new issues if I run into more problems in the future.