shoyer opened this issue 4 years ago
Hi! I wrote the current Haiku implementation. I'll answer the questions in reverse order :)
> How do I call jax.vjp on a Module?
If you look at https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/stateful.py you can see how we define a bunch of wrappers around JAX functions to work with Haiku. There's a lot of code, but the idea is simple: temporarily grab the global state and thread it through the JAX function's inputs, then make it global state again within the function (and reverse the swap when returning). We don't have a wrapper around vjp right now (we have one for grad), but it shouldn't be too hard to add.
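Here is a toy illustration of that pattern, not Haiku's actual code: _GLOBAL_STATE, impure_fn, and vjp_with_state are made-up names standing in for Haiku's frame state and the real wrappers in stateful.py, which also thread updated state back out of the transform.

import jax
import jax.numpy as jnp

_GLOBAL_STATE = {"scale": jnp.array(2.0)}  # stand-in for Haiku's frame state

def impure_fn(x):
  # Reads global state, the way code inside a Haiku module does.
  return _GLOBAL_STATE["scale"] * jnp.sum(x ** 2)

def vjp_with_state(fn, x):
  """hk.vjp-style wrapper sketch around jax.vjp."""
  state = dict(_GLOBAL_STATE)      # temporarily grab the global state

  def pure_fn(x):
    global _GLOBAL_STATE
    saved, _GLOBAL_STATE = _GLOBAL_STATE, state   # make it global again inside
    try:
      return fn(x)
    finally:
      _GLOBAL_STATE = saved                       # and reverse when returning

  return jax.vjp(pure_fn, x)

out, vjp_fn = vjp_with_state(impure_fn, jnp.ones(3))
print(vjp_fn(1.0))  # ([4., 4., 4.],): gradient of scale * sum(x**2) at ones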
> How can I implement this in Haiku?
I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, with examples of how you could use it with a Haiku Module (and with Flax/etc. modules, since presumably they'd all work)?
As a second question, has this approach been used before/do you know how well it works on GPUs/TPUs? The approximation we use is the one used by SNGAN (https://arxiv.org/pdf/1802.05957.pdf) and BigGAN, and we know it remains quite stable on accelerators. I'd be curious whether you have run any experiments checking the numerics of your exact approach.
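For reference, here is a sketch of that SNGAN-style update for a single dense weight matrix W, written from the paper's description rather than taken from hk.SpectralNorm (sngan_step is a made-up name):

import jax
import jax.numpy as jnp

def sngan_step(W, u, eps=1e-4):
  """One SNGAN-style power-iteration step on a dense matrix W.

  u is the persistent estimate of the first left-singular vector.
  Returns the estimated spectral norm and the updated u.
  """
  v = W.T @ u
  v = v * jax.lax.rsqrt((v ** 2).sum() + eps)
  u_new = W @ v
  u_new = u_new * jax.lax.rsqrt((u_new ** 2).sum() + eps)
  sigma = u_new @ W @ v
  return sigma, u_new

Because this only ever sees the raw weight matrix, it is exact for dense layers but, as noted elsewhere in the thread, only approximate for convolutions.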
> I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, with examples of how you could use it with a Haiku Module (and with Flax/etc. modules, since presumably they'd all work)?
For use in neural network training, I think you would still want to estimate the vector corresponding to the largest singular value in an online fashion.
Here's a clearer way to separate the logic:
import jax
import jax.numpy as jnp
from jax import lax

def _l2_normalize(x, eps=1e-4):
  return x * jax.lax.rsqrt((x ** 2).sum() + eps)

def _l2_norm(x):
  return jnp.sqrt((x ** 2).sum())

def _power_iteration(A, u, n_steps=10):
  """Update an estimate of the first right-singular vector of A()."""
  def fun(u, _):
    v, A_transpose = jax.vjp(A, u)  # v = A(u); the vjp applies A^T
    u, = A_transpose(v)
    u = _l2_normalize(u)
    return u, None
  u, _ = lax.scan(fun, u, xs=None, length=n_steps)
  return u

def estimate_spectral_norm(f, x, seed=0, n_steps=10):
  """Estimate the spectral norm of f(x) linearized at x."""
  rng = jax.random.PRNGKey(seed)
  u0 = jax.random.normal(rng, x.shape)
  _, f_jvp = jax.linearize(f, x)
  u = _power_iteration(f_jvp, u0, n_steps)
  sigma = _l2_norm(f_jvp(u))
  return sigma
I can imagine estimate_spectral_norm being a separately useful utility, but in a spectral normalization layer, you'd want to save the vector u0 as state on the layer and only use a handful of power iterations in each neural net evaluation.
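A minimal, untested sketch of that idea in Haiku, assuming the helper functions above (_power_iteration, _l2_norm) are in scope; LinearizedSpectralNorm is a made-up name, not the existing hk.SpectralNorm:

import haiku as hk
import jax

class LinearizedSpectralNorm(hk.Module):
  """Online estimate of the spectral norm of f linearized at x."""

  def __call__(self, f, x, n_steps=1):
    # Persist the singular-vector estimate across calls as Haiku state, so
    # each forward pass only needs a few power-iteration steps.
    u = hk.get_state("u", shape=x.shape, dtype=x.dtype,
                     init=hk.initializers.RandomNormal())
    _, f_jvp = jax.linearize(f, x)
    u = _power_iteration(f_jvp, u, n_steps=n_steps)
    hk.set_state("u", jax.lax.stop_gradient(u))
    return _l2_norm(f_jvp(u))

You would run this under hk.transform_with_state so the refined u comes back as part of the layer state.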
> As a second question, has this approach been used before/do you know how well it works on GPUs/TPUs?
The same approach (but written in a much more awkward/manual way) was used in this ICLR 2019 paper. Numerically, they should be identical. If you're using fully-connected layers, the calculation is exactly the same as the older method, just using autodiff instead of explicit matrix/vector products.
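As a quick sanity check of that claim (assuming estimate_spectral_norm from the code above is in scope), for a dense layer the linearized estimate should match the largest singular value of the weight matrix computed directly:

import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
dense = lambda x: W @ x

approx = estimate_spectral_norm(dense, jnp.ones(8), n_steps=20)
exact = jnp.linalg.norm(W, ord=2)   # largest singular value of W
print(approx, exact)                # should agree to a few decimal places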
From a fundamental perspective, I would guess this is quite efficient and numerically stable on accelerators, because the operations it uses are exactly the same as those at the core of neural net training:
The cost of doing a single power iteration is thus roughly equivalent to that of pushing a single additional example through the neural net.
(The version I wrote in this comment is slightly simpler than the version in the ICLR 2019 paper, because it uses norm(A(u)) rather than v @ A(u) to calculate the singular value and only normalizes once per iteration, but I doubt those differences matter much, and they are not hard to change.)
Hi @Cogitans, I'm trying to add spectral normalization to Flax and am modeling it after the Haiku version. I had some questions:
1. Can I use jax.tree_map to spectral normalize the params? e.g. params = jax.tree_map(lambda x: spectral_normalize(x), params)
2. Why is lax.stop_gradient used on u0 and sigma?
I noticed earlier today that Haiku has SpectralNormalization -- very cool!
I'm interested in implementing an improved version, which does a much better job estimating the norm for convolutional layers and should converge to the correct answer for any linear operator. The trick is to use auto-diff to calculate the transpose of the linear operator. In contrast, the current implementation is only accurate for dense matrices.
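A small illustration of that trick (not taken from the notebook): for a linear map f(x) = A @ x, the function returned by jax.vjp applies the transpose A^T.

import jax
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)
f = lambda x: A @ x

_, f_vjp = jax.vjp(f, jnp.ones(3))
v = jnp.array([1.0, 2.0])
print(jnp.allclose(f_vjp(v)[0], A.T @ v))  # True: the vjp applies A^T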
Here's my implementation in pure JAX: https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc
My question: how can I implement this in Haiku?
How do I call jax.vjp on a Module? I'm guessing (though to be honest I haven't checked yet) that the normal JAX function would break, given the way that Haiku adds mutable state.