google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Can we get kernel_fn from flax models? #99

Closed chenwydj closed 3 years ago

chenwydj commented 3 years ago

Currently, most Neural Tangents examples get kernel_fn from stax.serial.

Is there a more advanced way to get kernel_fn from complex models?

For example, can we get kernel_fn from a Flax nn.Module like this one (from here)?

```python
from typing import Any, Callable, Tuple

import flax.linen as nn

# As in the Flax examples: a stand-in for a partially-applied module
# class (e.g. nn.Conv or nn.BatchNorm).
ModuleDef = Any


class ResNetBlock(nn.Module):
  """ResNet block."""
  filters: int
  conv: ModuleDef
  norm: ModuleDef
  act: Callable
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x
    y = self.conv(self.filters, (3, 3), self.strides)(x)
    y = self.norm()(y)
    y = self.act(y)
    y = self.conv(self.filters, (3, 3))(y)
    y = self.norm(scale_init=nn.initializers.zeros)(y)

    # Project the residual when its shape does not match the block output.
    if residual.shape != y.shape:
      residual = self.conv(self.filters, (1, 1),
                           self.strides, name='conv_proj')(residual)
      residual = self.norm(name='norm_proj')(residual)

    return self.act(residual + y)
```

Thank you!

sschoenholz commented 3 years ago

Hey @chenwydj, thanks for checking out NT!

There are two different ways of answering your question. First, it is possible right now to combine Flax and NT to compute the empirical neural tangent kernel, or to estimate the infinite-width kernel via Monte Carlo sampling. Basically, the functions nt.empirical_kernel_fn and nt.monte_carlo_kernel_fn accept an (init_fn, apply_fn) pair that is more or less equivalent to model.init and model.apply in Flax.

There are a few small issues of convention that must be hacked around: first, Flax's init function expects example inputs, whereas we call init functions with shape information only; second, we pass an rng key to the model as an rng keyword argument, so the Flax model's __call__ function must either accept an rng key explicitly or accept (and ignore) unused keyword arguments.

Here is a colab notebook that I hacked together that gets this working, using a slight variation on the ResNet you posted above.
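To sketch the idea, here is a minimal, illustrative wrapper (not the notebook's exact code; the ResNetBlock instantiation, the LayerNorm choice, and the input shapes below are assumptions for illustration):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import neural_tangents as nt

# Illustrative instantiation of the ResNetBlock above. LayerNorm is used
# instead of BatchNorm so that `model.apply` needs no mutable collections.
model = ResNetBlock(filters=16, conv=nn.Conv, norm=nn.LayerNorm, act=nn.relu)

x = jnp.ones((3, 8, 8, 16))  # (batch, height, width, channels)


def init_fn(rng, input_shape):
  # Flax initializes from example inputs, while NT passes shapes only,
  # so build a dummy input of the requested shape.
  dummy = jnp.ones(input_shape)
  params = model.init(rng, dummy)
  output_shape = jax.eval_shape(model.apply, params, dummy).shape
  return output_shape, params  # stax-style (output_shape, params) pair.


def apply_fn(params, x, **kwargs):
  # NT may pass extra keyword arguments (e.g. `rng`); this deterministic
  # model ignores them.
  return model.apply(params, x)


# Empirical NTK at a single draw of the parameters.
_, params = init_fn(jax.random.PRNGKey(0), x.shape)
kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk = kernel_fn(x, x, 'ntk', params)

# Monte Carlo estimate of the infinite-width NNGP/NTK kernels.
mc_kernel_fn = nt.monte_carlo_kernel_fn(
    init_fn, apply_fn, key=jax.random.PRNGKey(1), n_samples=32)
kernels = mc_kernel_fn(x, x, ('nngp', 'ntk'))
```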

On the other hand, you might have meant "can you compute the analytic kernel function from a Flax network?" This is a significantly more difficult problem. We are working on getting something together that can partially compute these kernels, but I don't have a firm timeframe for when we might have something ready for testing. As we make progress, I'll try to remember to bump this issue.

romanngg commented 3 years ago

Just in case, a brief comment on stax.serial: note that we also have stax.parallel, and in principle arbitrary topologies can be implemented, e.g. the ResNet in the example https://github.com/google/neural-tangents#infinitely-wideresnet, sketched below. (But this is likely not what you were asking for; otherwise, +1 to Sam: supporting Flax is a great, but also very ambitious, feature.)
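For reference, a residual block built from nt.stax combinators, lightly adapted from the README's WideResNet example (the channel count below is arbitrary):

```python
from neural_tangents import stax


def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """A residual block: main branch and shortcut summed via FanInSum."""
  main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  shortcut = (stax.Identity() if not channel_mismatch else
              stax.Conv(channels, (3, 3), strides, padding='SAME'))
  return stax.serial(stax.FanOut(2),              # duplicate the input
                     stax.parallel(main, shortcut),
                     stax.FanInSum())             # add the two branches


# `kernel_fn` here is the analytic (infinite-width) kernel function.
init_fn, apply_fn, kernel_fn = WideResnetBlock(channels=64)
```

Calling kernel_fn(x1, x2, 'ntk') then returns the analytic NTK for this topology.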

chenwydj commented 3 years ago

Thank you very much @sschoenholz and @romanngg for your great answers!

I think I will stick to stax.serial for now! :)