Open geajack opened 3 years ago

Is it possible to use this library to implement a network with a Softmax output layer? I was surprised to find this layer type neither implemented nor mentioned in the docs, since it's so common.
Unfortunately there are several ways in which softmax doesn't play nice in the infinite-width settings that we consider in neural tangents:
1) Suppose you have a softmax over the infinite-width axis somewhere in your network. It is easy to see that the infinite-width output of the softmax will be a constant 0 for each of the infinitely many output units, which is not interesting (see the short numerical sketch after this list).
2) If you have a softmax over a finite axis, then outputs in the infinite-width limit will have a finite covariance, but AFAIK there is no known closed-form expression for this covariance (but please correct me if I'm wrong). This means we can't compute NNGP and NTK matrices using `kernel_fn` in `nt.stax` if the network has a softmax anywhere in it.
3) If you put a softmax on top and treat outputs as categorically distributed, then your outputs are no longer Gaussian-distributed, and you can't do efficient inference with `nt.predict.gradient_descent_mse_ensemble` and `nt.predict.gp_inference`.
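To illustrate point 1, here's a quick (untested) numerical sketch: each softmax output over the width axis shrinks towards 0 as the width grows.

```python
import jax
import jax.numpy as jnp

# As the width n grows, each entry of softmax over the width axis shrinks
# towards 0 (roughly like 1 / n for comparable logits).
key = jax.random.PRNGKey(0)
for n in [10, 1_000, 100_000]:
  logits = jax.random.normal(key, (n,))
  print(n, float(jnp.max(jax.nn.softmax(logits))))
# The printed maxima decrease towards 0 as n grows.
```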
For these reasons we don't have a softmax function in `nt.stax`. To be clear, here are cases where you can still use it:
a) You can compute empirical kernels using `nt.monte_carlo_kernel_fn` and `nt.empirical_kernel_fn` with softmax, and in fact with any function. You can use `jax.experimental.stax`, `flax`, or just `jax.numpy`, etc., to define your networks and pass them to these empirical kernel decorators (see the sketch right after this list).
b) You can pass these empirical kernels to `nt.predict.gradient_descent_mse` (if you use MSE loss) and `nt.predict.gradient_descent` (for any loss, cross-entropy included).
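For instance, a rough (untested) sketch of (a), using only `jax.numpy` to define a softmax network and `nt.empirical_ntk_fn` for the kernel (shapes and names here are just for illustration):

```python
import jax
import jax.numpy as jnp
import neural_tangents as nt

# A plain jax.numpy network with a softmax output; any such function can be
# passed to the empirical-kernel utilities.
def apply_fn(params, x):
  w1, w2 = params
  h = jax.nn.relu(x @ w1)
  return jax.nn.softmax(h @ w2, axis=-1)

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(key1, (784, 512)) / jnp.sqrt(784.),
          jax.random.normal(key2, (512, 10)) / jnp.sqrt(512.))
x = jax.random.normal(key3, (20, 784))

# Empirical (finite-width) NTK of this particular network at these parameters.
ntk_fn = nt.empirical_ntk_fn(apply_fn)
ntk = ntk_fn(x, None, params)  # (20, 20) with the default trace_axes=(-1,)
```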
Finally, if you want to use `nt.stax` with softmax, but you're OK with `kernel_fn` not working, you could add a layer like this into `stax.py`:
```python
@layer
@supports_masking(remask_kernel=False)
def Softmax(axis: int = -1) -> InternalLayer:
  def fn(x):
    return jax.nn.softmax(x, axis=axis)

  @_requires(diagonal_spatial=_Diagonal())  # pytype:disable=wrong-keyword-args
  def kernel_fn(k: Kernel) -> Kernel:
    raise NotImplementedError

  return _elementwise(fn, f'Softmax(axis={axis})', kernel_fn)
```
This way you can use `nt.stax` layers and their `apply_fn`, `init_fn`, empirical kernels, etc., but you will get an error as soon as `kernel_fn` is called.
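Usage would then look something like this (untested sketch, assuming the `Softmax` layer above has been added to `stax.py`):

```python
from neural_tangents import stax

# Untested sketch, assuming the Softmax layer sketched above exists in stax.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),
    stax.Relu(),
    stax.Dense(10),
    stax.Softmax()
)
# init_fn, apply_fn and empirical kernels work as usual; calling
# kernel_fn(x1, x2) raises NotImplementedError because of the softmax.
```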
Lmk if this helps!
Hi, thanks for the fast answer and sorry for my not-so-fast response.
That does help. The reason I was confused is basically that I'm not interested in the infinite-width case; I currently mainly use this library as a tool for computing empirical NTKs, so I couldn't really see why softmax outputs would be an issue. That makes sense now.
To be clear, I don't actually need to implement a custom layer in that way, right? If I understand correctly, I can basically pass any Python function to `empirical_kernel_fn`. So I could do something like this (I haven't run this, but I think it's roughly the right idea):
```python
initializer, stax_model, kernel = stax.serial(
    stax.Dense(1000, parameterization="standard"),
    stax.Relu(),
    stax.Dense(10, parameterization="standard")
)

_, w0 = initializer(rng_key, (-1, 784))  # init_fn also takes an input shape (flattened MNIST here)

def model(w, x):
    return jax.nn.softmax(stax_model(w, x))

tangent_kernel = nt.empirical_kernel_fn(model)
H = tangent_kernel(train_x, None, 'ntk', w0)
```
Right?
However, something I'm confused about is that when I do this (using your approach), my tangent kernel is spitting out real number values:
```python
H = np.array(tangent_kernel(kernel_batch.xs, None, 'ntk', w0))
assert H.shape == (100, 100)  # kernel_batch is a batch of 100 inputs from MNIST
```
My understanding is that when the network has an n-dimensional output layer, the tangent kernel is n×n-matrix-valued. So I would have expected `H` to be a 100×100 matrix of 10×10 matrices.
EDIT: Ah, I believe I have it. I went back to the docs. So to get what I want, if I understand correctly, I should do:
```python
tangent_kernel = nt.empirical_kernel_fn(model, trace_axes=())
```
And what I was doing before basically gives the traces of the matrices that I expect. Actually, judging by the results I'm getting, does it return the "normalized" traces, i.e. trace(A) / n where A is n×n?
One thing did bother me when reading the docs though. I don't quite understand this passage:

> For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you.

Is that relevant to my case?
> And what I was doing before basically gives the traces of the matrices that I expect. Actually, judging by the results I'm getting, does it return the "normalized" traces, i.e. trace(A) / n where A is n×n?
Yes, on this, and everything above, you are completely correct (you can pass any function)!
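A quick (untested) sanity check of the trace(A) / n relationship, reusing `model`, `train_x`, `w0` from your snippet and assuming the full kernel comes back with shape (100, 100, 10, 10):

```python
import jax.numpy as jnp
import neural_tangents as nt

# Reusing `model`, `train_x`, `w0` from the snippet above.
full_fn = nt.empirical_kernel_fn(model, trace_axes=())  # keep the output-output blocks
traced_fn = nt.empirical_kernel_fn(model)               # default trace_axes=(-1,)

H_full = full_fn(train_x, None, 'ntk', w0)      # assumed shape (100, 100, 10, 10)
H_traced = traced_fn(train_x, None, 'ntk', w0)  # shape (100, 100)

# The traced kernel should match the normalized trace of each 10x10 block.
print(jnp.allclose(H_traced, jnp.trace(H_full, axis1=-2, axis2=-1) / 10, atol=1e-5))
```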
> One thing did bother me when reading the docs though. I don't quite understand this passage:
>
> "For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you."
>
> Is that relevant to my case?
No, this is relevant when the output of your network is, for example, a list of arrays vs. a single array (e.g. if you have `stax.parallel` / `stax.FanOut` layers at the top). In this case, the full empirical kernel would be a list of lists containing covariances between each pair of list entries, but here we return only the "diagonal" of this list, i.e. covariances of entries with themselves. (Mildly related: for taking the diagonal of array outputs, you can pass the `diagonal_axes` argument, and you'll get the diagonal instead of the normalized trace.)
Finally, I recommend looking into the `vmap_axes` argument in the docs. tl;dr: if you don't have interactions between different batch elements and `0` is the leading batch dimension from inputs to outputs, you can set `vmap_axes=0` to get a good speedup.
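For example (untested, reusing `model` from the snippet above):

```python
# If axis 0 is the batch dimension from inputs to outputs and batch elements
# don't interact (true for a plain MLP), vmap_axes=0 gives a good speedup:
tangent_kernel = nt.empirical_kernel_fn(model, trace_axes=(), vmap_axes=0)

# Related: diagonal_axes=(-1,) keeps only the per-output diagonal entries
# instead of the full 10x10 blocks (or their normalized trace).
diag_kernel = nt.empirical_kernel_fn(model, trace_axes=(), diagonal_axes=(-1,))
```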
By the way, for what it's worth, I find that in my case the most convenient way to do this is the following:
```python
initializer, model, kernel = stax.serial(
    stax.Dense(1000, parameterization="standard"),
    stax.Relu(),
    stax.Dense(10, parameterization="standard"),
    stax._elementwise(jax.nn.softmax, name="softmax", kernel_fn=None)
)
```
But what about computing the Monte Carlo estimate for the infinite-width case, i.e. $\mathbb{E}[\mathrm{softmax}(u)\,\mathrm{softmax}(v)]$ with $u, v$ drawn from a multivariate Gaussian with a block-symmetric covariance?
Yes, that's another way to do it; we don't have a super-convenient function for it. We have https://neural-tangents.readthedocs.io/en/latest/monte_carlo.html for MC-estimating kernels of any functions, but I assume you want `u` and `v` sampled from the Gaussian with the exact NNGP kernel of the penultimate layer. So you'd need to do something like:
```python
_, _, kernel_fn = stax.serial(...)  # all layers except softmax
nngp = kernel_fn(x1, x2, get='nngp')
# Sample u and v from N(0, nngp) and compute the MC kernel
...
```
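For a single pair of inputs, a rough (untested) sketch of that sampling step might look like this, treating each of the n pre-softmax units as an i.i.d. Gaussian pair with the 2x2 covariance given by the penultimate-layer NNGP (`mc_softmax_kernel` and its arguments are just illustrative names):

```python
import jax
import jax.numpy as jnp

def mc_softmax_kernel(key, k11, k12, k22, n_units=10, n_samples=100_000):
  # Each of the n_units pre-softmax units is an i.i.d. Gaussian pair (u_i, v_i)
  # with covariance [[k11, k12], [k12, k22]] taken from the penultimate NNGP.
  cov = jnp.array([[k11, k12], [k12, k22]])
  chol = jnp.linalg.cholesky(cov)
  z = jax.random.normal(key, (n_samples, n_units, 2))
  uv = z @ chol.T                          # correlated (u, v) samples
  u = jax.nn.softmax(uv[..., 0], axis=-1)  # softmax over the unit axis
  v = jax.nn.softmax(uv[..., 1], axis=-1)
  # Monte Carlo estimate of E[softmax(u) softmax(v)^T]: an (n_units, n_units) matrix.
  return jnp.einsum('si,sj->ij', u, v) / n_samples
```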
OK, thanks. There is also "Exploring Alternatives to Softmax Function", which contains a Taylor approximation of softmax; maybe there is a closed-form solution for the Taylor approximation. I might have to ask that on MathOverflow.