Thanks for the repro and good find! It's indeed a bug in our custom differentiation rule for the square root, where we clip the derivative around zero but accidentally clipped the outputs as well. I've sent a change to fix it, but it needs code review, so it will likely land tomorrow. In the meantime, this is what the change looks like: https://github.com/google/neural-tangents/blob/94e7498863916ee4fa44e448ae02fc9682da9f27/neural_tangents/stax.py#L4343
 def _sqrt_jvp(tol, primals, tangents):
   x, = primals
   x_dot, = tangents
   safe_tol = max(tol, 1e-30)
   square_root = _sqrt(x, safe_tol)
+  square_root_out = _sqrt(x, tol)
-  return square_root, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
+  return square_root_out, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
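For illustration, here is a minimal standalone sketch of the same pattern using `jax.custom_jvp` (names like `safe_sqrt` are hypothetical, not the actual neural-tangents internals): the derivative is clipped near zero, while the returned primal is left unclipped, which is what the fix above restores.

```python
import jax
import jax.numpy as np


@jax.custom_jvp
def safe_sqrt(x):
  # Primal output: plain square root, only floored at 0 so tiny negative
  # rounding errors don't produce nan.
  return np.sqrt(np.maximum(x, 0.))


@safe_sqrt.defjvp
def _safe_sqrt_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  safe_tol = 1e-30
  # Clip only inside the derivative: the denominator uses a floored input so
  # the tangent stays finite near zero, but the returned primal is unclipped.
  primal_out = np.sqrt(np.maximum(x, 0.))
  denom = np.sqrt(np.maximum(x, safe_tol))
  tangent_out = np.where(x > safe_tol, x_dot / (2 * denom), 0.)
  return primal_out, tangent_out


print(safe_sqrt(0.))            # 0.0 (output is not clipped to sqrt(safe_tol))
print(jax.grad(safe_sqrt)(0.))  # 0.0, no nan/inf
print(jax.grad(safe_sqrt)(4.))  # 0.25
```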
Wow thanks for quickly determining the issue!
Hmm, after pulling 8b7917f, I see that the values match now, but the grad is all nans. Is that the intended outcome?
Thanks, I'll need to look into this. In the meantime, I suspect it's only happening for zero-valued inputs and generally shouldn't be a problem otherwise (but perhaps I'm wrong, so it's worth double-checking whether there are still nans or discrepancies on normal inputs like images, etc.).
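For instance, a hypothetical check along these lines (not from the thread), using a small Conv/Relu/Flatten kernel and a random nonzero input where the gradient should be finite:

```python
import jax
from neural_tangents import stax

# Same kind of kernel as in the repro below, evaluated at a nonzero input.
_, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten())

def f(x):
  return kernel_fn(x, x, 'ntk')[0][0]

x = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 3))
print(jax.grad(f)(x))  # expected: finite values, no nans
```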
A much smaller example reproducing the nan issue:

import jax
from jax import *
from neural_tangents.stax import *
def f(x): return serial(Conv(1, (3, 3)), Relu(), Flatten())[2](x, x, "ntk")[0][0]
print(grad(f)(jax.numpy.zeros((1, 32, 32, 3))))
I guess this no longer has anything to do with parameterization="standard" (it happens either way), so should this be a new issue?
I think it's probably the same issue, likely related to differentiating kernel functions with Relu (maybe some other nonlinearities too, will need to look into this) at exactly zero, i.e. jax.numpy.zeros((1, 32, 32, 3)). I agree it's likely not specific to parameterization.
Thank you for your patience here! I think you were right that there were actually two bugs here.
One was the wrong treatment of biases with b_std=None, parameterization='standard' (precisely, in the standard parameterization, having no bias (b_std=None) and having a zero-variance bias (b_std=0) are not the same).
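To make the distinction concrete, here is a hedged sketch (hypothetical widths and values, assuming stax.Dense with W_std/b_std/parameterization arguments) comparing the two settings; presumably a zero-variance bias is still a trainable parameter and so still contributes to the NTK, whereas b_std=None removes the bias entirely.

```python
import jax.numpy as jnp
from neural_tangents import stax

x = jnp.ones((1, 8))

def ntk(b_std):
  # Two-layer network in the standard parameterization; only b_std differs.
  _, _, kernel_fn = stax.serial(
      stax.Dense(512, W_std=1., b_std=b_std, parameterization='standard'),
      stax.Relu(),
      stax.Dense(1, W_std=1., b_std=b_std, parameterization='standard'))
  return kernel_fn(x, x, 'ntk')

print(ntk(None))  # no bias parameters at all
print(ntk(0.))    # zero-variance biases still exist, so the NTK differs
```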
The other was that the derivative at x1 or x2 being zero is technically undefined for ReLU and similar activations. We now set the gradient at zero inputs to be 0. Note that this is correct for x1 = x2 = 0 and for x2 = 0, but e.g. dK(x1, x2 != 0)/dx1 is genuinely undefined/infinite at x1 = 0, yet we will return 0. While this is technically incorrect, it matches JAX's behavior of defining the gradient of non-differentiable functions as the mean subgradient, e.g. in JAX jax.grad(jax.numpy.sign)(0.) == 0., or jax.grad(lambda x: jax.numpy.maximum(x, 0.))(0.) == 0.5, so arguably this is a reasonable value to return.
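For reference, a quick snippet to check the JAX behavior cited above (assuming a recent JAX version):

```python
import jax
import jax.numpy as jnp

# JAX returns the mean subgradient at non-differentiable points.
print(jax.grad(jnp.sign)(0.))                      # 0.0
print(jax.grad(lambda x: jnp.maximum(x, 0.))(0.))  # 0.5
```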
Hope this helps!
Thanks so much for the thorough fix! All of the gradient-related anomalies I've been seeing have gone away. I'll open new issues if I run into more problems in the future.
I am confused by the behavior of the following snippet of code (the WideResNet from the README with standard parameterization):
My understanding is that the two printed values should be the same. However, when I run it, I get two totally different values:
Is my understanding correct? I have not yet found a simpler network that features this behavior.
Versions:
- jax 0.2.20
- jaxlib 0.1.71+cuda111
- neural-tangents 0.3.7