google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

New interface for spectral normalization #71

Open shoyer opened 4 years ago

shoyer commented 4 years ago

I noticed earlier today that Haiku has SpectralNormalization -- very cool!

I'm interested in implementing an improved version, which does a much better job estimating the norm for convolutional layers and should converge to the correct answer for any linear operator. The trick is to use auto-diff to calculate the transpose of the linear operator. In contrast, the current implementation is only accurate for dense matrices.
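As a small illustration of that trick, here is a sketch (not the implementation from the linked notebook) that uses `jax.vjp` to obtain the transpose of a 1-D convolution without ever building a matrix. The kernel, the test vectors, and the names `conv` and `conv_transpose` are made up for the example. The check at the end is the defining property of a transpose, <A u, w> = <u, A^T w>.

```python
import jax
import jax.numpy as jnp

# Illustrative fixed kernel; any linear operator would do here.
kernel = jnp.array([1.0, -2.0, 0.5])

def conv(x):
  """A linear operator: 1-D convolution of x with a fixed kernel."""
  return jnp.convolve(x, kernel, mode="same")

u = jnp.arange(1.0, 6.0)                      # input-space vector
w = jnp.array([0.5, -1.0, 2.0, 0.0, 1.5])     # output-space vector

# jax.vjp returns conv(u) and a function that applies the transpose of conv.
v, conv_transpose = jax.vjp(conv, u)
(w_t,) = conv_transpose(w)

# Adjoint identity: <conv(u), w> should equal <u, conv^T(w)>.
lhs = jnp.vdot(v, w)
rhs = jnp.vdot(u, w_t)
```

Because the transpose comes from autodiff, the same pattern applies to any linear function JAX can differentiate, which is what makes the approach independent of the layer type.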

Here's my implementation in pure JAX: https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc

My question: how can I implement this in Haiku?

Cogitans commented 4 years ago

Hi! I wrote the current Haiku implementation. I'll answer the questions in reverse order :)

How do I call jax.vjp on module?

If you look at https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/stateful.py you can see how we define a bunch of wrappers around JAX functions so that they work with Haiku. There's a lot of code, but the idea is simple: temporarily grab the global state and thread it through the JAX function's inputs, then make it global state again within the function (and do the reverse when returning). We don't have a wrapper around vjp right now (we have one for grad), but it shouldn't be too hard to write.

how can I implement this in Haiku?

I definitely think it would be good to have this function somewhere, but I'm a little hesitant on suggesting where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or Parameters (unlike hk.SpectralNorm which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, and have examples on how you could use it with a Haiku Module (along with a Flax/etc modules, since presumably they'd all work)?

As a second question, has this approach been used before/do you know how well it works on GPUs/TPUs? The approximation we use is the one used by SNGAN (https://arxiv.org/pdf/1802.05957.pdf) and BigGAN, and we know it remains quite stable on accelerators. I'd be curious whether you have run any experiments checking the exact approach's numerics.

shoyer commented 4 years ago

I definitely think it would be good to have this function somewhere, but I'm a little hesitant on suggesting where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or Parameters (unlike hk.SpectralNorm which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, and have examples on how you could use it with a Haiku Module (along with a Flax/etc modules, since presumably they'd all work)?

For use in neural network training, I think you would still want to estimate the vector corresponding to the largest singular value in an online fashion.

Here's a clearer way to separate the logic:


import jax
import jax.numpy as jnp
from jax import lax

def _l2_normalize(x, eps=1e-4):
  # eps guards against division by zero when x is all zeros.
  return x * jax.lax.rsqrt((x ** 2).sum() + eps)

def _l2_norm(x):
  return jnp.sqrt((x ** 2).sum())

def _power_iteration(A, u, n_steps=10):
  """Update an estimate of the first right-singular vector of A()."""
  def fun(u, _):
    # jax.vjp returns v = A(u) and a function that applies the transpose of A.
    v, A_transpose = jax.vjp(A, u)
    u, = A_transpose(v)  # u <- A^T A u
    u = _l2_normalize(u)
    return u, None
  u, _ = lax.scan(fun, u, xs=None, length=n_steps)
  return u

def estimate_spectral_norm(f, x, seed=0, n_steps=10):
  """Estimate the spectral norm of f(x) linearized at x."""
  rng = jax.random.PRNGKey(seed)
  u0 = jax.random.normal(rng, x.shape)
  # f_jvp is the linear map u -> J u, where J is the Jacobian of f at x.
  _, f_jvp = jax.linearize(f, x)
  u = _power_iteration(f_jvp, u0, n_steps)
  sigma = _l2_norm(f_jvp(u))
  return sigma

I can imagine estimate_spectral_norm being a separately useful utility, but in a spectral normalization layer, you'd want to save the vector u0 as state on the layer and only use a handful of power iterations in each neural net evaluation.
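A rough sketch of what that online variant could look like as a pure function, with the caller holding `u` between evaluations the way a layer would hold state. The name `spectral_norm_step`, the diagonal test matrix, and the loop of eight evaluations are assumptions chosen for illustration; this is not a Haiku module.

```python
import jax
import jax.numpy as jnp

def _l2_normalize(x, eps=1e-4):
  return x * jax.lax.rsqrt((x ** 2).sum() + eps)

def spectral_norm_step(f, x, u, n_steps=1):
  """Refine the stored vector u with a few power iterations.

  Returns (sigma, new_u); the caller keeps new_u for the next evaluation.
  """
  _, f_jvp = jax.linearize(f, x)
  for _ in range(n_steps):
    v, f_transpose = jax.vjp(f_jvp, u)
    (u,) = f_transpose(v)
    u = _l2_normalize(u)
  sigma = jnp.sqrt((f_jvp(u) ** 2).sum())
  return sigma, u

# Illustrative linear layer whose largest singular value is 3.
W = jnp.diag(jnp.array([3.0, 1.0, 0.5]))
f = lambda x: W @ x
x = jnp.zeros(3)

u = jnp.ones(3)  # stored "state", carried across evaluations
sigmas = []
for _ in range(8):  # eight evaluations, one power iteration each
  sigma, u = spectral_norm_step(f, x, u)
  sigmas.append(float(sigma))
```

With one iteration per call, the estimate starts slightly below the true value and tightens as `u` is reused across calls, which is the point of storing it rather than re-sampling it each time.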

As a second question, has this approach been used before/do you know how well it works on GPUs/TPUs?

The same approach (but written in a much more awkward/manual way) was used in this ICLR 2019 paper. Numerically, they should be identical. If you're using fully-connected layers, the calculation is exactly the same as the older method, just using autodiff instead of explicit matrix/vector products.
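For the fully-connected case, the equivalence can be checked directly. The sketch below uses an arbitrary 2x3 matrix (an assumption for the example) and compares one unnormalized power-iteration update computed with explicit matrix/vector products against the same update computed with `jax.vjp`.

```python
import jax
import jax.numpy as jnp

# Arbitrary dense layer mapping R^3 -> R^2.
W = jnp.array([[2.0, 0.0, 1.0],
               [0.0, 1.0, -1.0]])
u = jnp.array([1.0, 2.0, 3.0])

# Explicit products: u <- W^T (W u).
explicit = W.T @ (W @ u)

# The same update via autodiff: vjp supplies the action of W^T.
v, W_transpose = jax.vjp(lambda x: W @ x, u)
(autodiff,) = W_transpose(v)
```

Both `explicit` and `autodiff` hold the vector W^T W u, so for dense layers the two formulations differ only in how the transpose is obtained.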

From a fundamental perspective, I would guess this is quite efficient and numerically stable on accelerators, because the operations it uses are exactly the same as those at the core of neural net training: a forward application of the layer (A(u)) and a reverse-mode pass through it (A_transpose(v)).

The cost of doing a single power iteration is thus roughly equivalent to that of pushing a single additional example through the neural net.

(The version I wrote in this comment is slightly simpler than the version in the ICLR 2019 paper, because it uses norm(A(u)) rather than v @ A(u) to calculate the singular value and only normalizes once per iteration, but I doubt those differences matter much, and they are not hard to change.)

chiamp commented 1 year ago

hi @Cogitans, I'm trying to add spectral normalization into Flax and am modeling it after the Haiku version. I had some questions: