google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Non-trainable layers #10

Closed: allenbai01 closed this issue 5 years ago.

allenbai01 commented 5 years ago

Thanks for making this great resource available!

Can the layers in stax (such as Conv and Dense) be specified as non-trainable? If not, is there a way to modify the returned apply_fn so that the layer becomes non-trainable?

sschoenholz commented 5 years ago

Great question! One nice thing about stax (and JAX in general) is that there isn't really a notion of "trainable" vs. "untrainable" parameters: everything is just nested tuples / lists / dicts of arrays. So if you don't want a layer to be trainable, you just need to zero that layer's gradients before passing them to the optimizer.
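A minimal sketch of that idea in plain JAX (not a Neural Tangents API): it assumes the gradients are a list with one entry per layer, with an empty tuple for parameter-free layers, and the helpers `mask_frozen_grads` and `sgd_step` as well as the toy arrays are hypothetical.

```python
import jax
import jax.numpy as jnp


def mask_frozen_grads(grads, frozen):
    """Replace the gradients of every layer index in `frozen` with zeros.

    `grads` is assumed to be a list with one (possibly empty) pytree per
    layer, mirroring the params list of a serial network.
    """
    return [
        jax.tree_util.tree_map(jnp.zeros_like, g) if i in frozen else g
        for i, g in enumerate(grads)
    ]


def sgd_step(params, grads, lr, frozen):
    """One plain SGD update that leaves the frozen layers unchanged."""
    grads = mask_frozen_grads(grads, frozen)
    return [
        jax.tree_util.tree_map(lambda w, dw: w - lr * dw, p, g)
        for p, g in zip(params, grads)
    ]


# Toy layout: Dense (W, b), parameter-free activation (), Dense (W, b).
params = [(jnp.ones((2, 3)), jnp.zeros(3)), (), (jnp.ones((3, 1)), jnp.zeros(1))]
grads = [
    (jnp.full((2, 3), 0.5), jnp.full(3, 0.5)),
    (),
    (jnp.full((3, 1), 0.5), jnp.full(1, 0.5)),
]
new_params = sgd_step(params, grads, lr=0.1, frozen={0})
```

Masking on the optimizer side keeps the network definition unchanged; the Frozen layer described next instead builds the freezing into apply_fn, so any optimizer sees zero gradients for that layer automatically.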

In stax, a convenient way to accomplish this is to introduce a new layer type that we'll call Frozen. Frozen layers act just like regular layers, but they won't be trained during gradient descent. We can make Frozen layers out of regular layers by introducing a function:

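from jax import lax
from jax.tree_util import tree_map
from neural_tangents import stax

@stax._layer
def Frozen(layer):
  """Wraps an (init_fn, apply_fn, kernel_fn) layer so its parameters are not trained."""
  init_fn, apply_fn, kernel_fn = layer

  def frozen_apply_fn(params, xs, **kwargs):
    # Same forward pass, but block gradients from flowing into `params`.
    params = tree_map(lax.stop_gradient, params)
    return apply_fn(params, xs, **kwargs)

  def frozen_kernel_fn(kernels):
    # Freezing does not change the forward pass, so the NNGP kernel is unchanged.
    if kernels.ntk is None:
      return kernel_fn(kernels, None, None)

    # Freezing does change the NTK; the analytic version is not implemented here.
    raise NotImplementedError('Analytic NTK is not implemented for Frozen layers.')

  return init_fn, frozen_apply_fn, frozen_kernel_fn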

This function takes a layer, specified by an (init_fn, apply_fn, kernel_fn) triple, and modifies the apply_fn so that it does not propagate gradients to its parameters. We also alter the kernel function to raise an error if we try to compute the NTK, since freezing a layer changes the NTK. If you want to compute the NTK, you can use the empirical NTK functions. If the analytic NTK for a frozen layer would be useful to you, I think it should be pretty easy to work out (it might just be the identity function).
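To see concretely why freezing changes the NTK: the empirical NTK is Theta(x, x') = sum over parameters p of J_p(x) J_p(x')^T, where J_p is the Jacobian of the output with respect to p, so stopping the gradient into a layer removes that layer's term from the sum. Here is a small pure-JAX sketch of this (it does not call neural_tangents; the two-layer network `f`, its widths, and the helper `ntk_terms` are hypothetical):

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import erf


def f(params, x, frozen=False):
    """Hypothetical 2-layer net: x -> erf(x W1 / sqrt(d)) W2 / sqrt(k)."""
    w1, w2 = params
    if frozen:
        # Same forward pass, but no gradient flows into W1.
        w1 = lax.stop_gradient(w1)
    h = erf(x @ w1 / jnp.sqrt(x.shape[-1]))
    return h @ w2 / jnp.sqrt(h.shape[-1])


def ntk_terms(params, x1, x2, frozen=False):
    """Per-parameter contributions J_p(x1) J_p(x2)^T to the empirical NTK."""
    j1 = jax.jacobian(lambda p: f(p, x1, frozen))(params)
    j2 = jax.jacobian(lambda p: f(p, x2, frozen))(params)
    n1, n2 = x1.shape[0], x2.shape[0]
    # The output width is 1, so each Jacobian flattens to (n, num_params).
    return [a.reshape(n1, -1) @ b.reshape(n2, -1).T for a, b in zip(j1, j2)]


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (3, 16)), jax.random.normal(k2, (16, 1)))
x = jax.random.normal(k3, (4, 3))

terms_full = ntk_terms(params, x, x)
terms_frozen = ntk_terms(params, x, x, frozen=True)
ntk_full = sum(terms_full)
ntk_frozen = sum(terms_frozen)  # equals ntk_full minus the W1 term
```

The empirical NTK functions in Neural Tangents compute this same Jacobian contraction from an apply_fn, so applied to a network containing a Frozen layer they should likewise omit the frozen layer's term.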

If you then want to create a network with a frozen layer, you can write:

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Erf(),
    Frozen(stax.Dense(512)), stax.Erf(),  # This layer will not train.
    stax.Dense(1)
)

We can try out the network in the Neural Tangents Cookbook by writing the following code once the dataset has been constructed:

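from jax import grad, jit
from jax.tree_util import tree_map
import jax.numpy as np

# net_key, train_xs and train_ys are defined earlier in the Cookbook.
_, params = init_fn(net_key, (-1, 1))

loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y) ** 2))
grad_loss = jit(grad(loss))

print('Gradient Norms:')
# Squared norm of every gradient array, with the same nesting as params.
norms = tree_map(lambda x: float(np.sum(x ** 2)),
                 grad_loss(params, train_xs, train_ys))
# Indices 0, 2 and 4 are the Dense layers; 1 and 3 are the parameter-free Erf layers.
for i in range(0, 5, 2):
  print('\nLayer {}: '.format(i // 2))
  print('(W, b): ', norms[i])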

This prints:

Gradient Norms:

Layer 0: 
('(W, b): ', (0.8965191841125488, 0.0))

Layer 1: 
('(W, b): ', (0.0, 0.0))

Layer 2: 
('(W, b): ', (0.44291678071022034, 0.0))

You can see that the gradients of the two parameters in the layer we froze are zero.
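Because those gradients are exactly zero, plain gradient descent never moves the frozen layer. A minimal pure-JAX sketch of that check, assuming the same one-entry-per-layer params layout (empty tuples for the Erf layers); the helpers `init`, `apply` and `sgd_step`, the `FROZEN` index set, and the toy sin-regression data are hypothetical and stand in for the stax network rather than calling neural_tangents:

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import erf

FROZEN = {2}  # hypothetical: index of the frozen Dense layer in the params list
LR = 0.1


def init(key, widths=(1, 32, 32, 1)):
    """Params as one entry per layer: [(W, b), (), (W, b), (), (W, b)]."""
    keys = jax.random.split(key, len(widths) - 1)
    params = []
    for i, k in enumerate(keys):
        if i > 0:
            params.append(())  # Erf layers have no parameters
        w = jax.random.normal(k, (widths[i], widths[i + 1]))
        params.append((w, jnp.zeros(widths[i + 1])))
    return params


def apply(params, x):
    for i, p in enumerate(params):
        if len(p) == 0:
            x = erf(x)  # parameter-free Erf layer
            continue
        w, b = p
        if i in FROZEN:
            # Same forward pass, but no gradient reaches this layer.
            w, b = lax.stop_gradient(w), lax.stop_gradient(b)
        x = x @ w / jnp.sqrt(w.shape[0]) + b
    return x


def loss(params, x, y):
    return 0.5 * jnp.mean((apply(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


params0 = init(jax.random.PRNGKey(0))
x = jnp.linspace(-3.0, 3.0, 32).reshape(-1, 1)
y = jnp.sin(x)

params = params0
for _ in range(100):
    params = sgd_step(params, x, y)
```

One caveat with this design: stop_gradient only zeroes the gradient, so an optimizer that applies decoupled weight decay directly to the parameters would still shrink the frozen weights; in that case masking the update itself is the safer option.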

allenbai01 commented 5 years ago

That worked well -- thanks a lot!