zhangxiangxiao / xjax

Simple framework for neural networks using Jax
BSD 3-Clause "New" or "Revised" License

Move mode argument to states or forward function? #2

Open ryanai3 opened 2 years ago

ryanai3 commented 2 years ago

Currently, train vs. test mode is defined at module construction. This means that if a user wants to run in both modes (e.g. evaluate accuracy on validation examples every 1000 training steps), the user needs to construct the module twice (or write code that implicitly does so). Example version A:

import xnn
train_forward, params, states = xnn.Sequential(
  xnn.Embed(1024, 32),
  xnn.Conv(32, 64, (3,)),
  xnn.ReLU(),
  xnn.Dropout(0.3),
  xnn.Conv(64, 32, (3,)),
  xnn.ReLU(),
  xnn.Dropout(0.3),
  xnn.Mean(axis=-1),
)
test_forward, _, _ = xnn.Sequential(
  xnn.Embed(1024, 32),
  xnn.Conv(32, 64, (3,)),
  xnn.ReLU(),
  xnn.Dropout(0.3, mode='test'),
  xnn.Conv(64, 32, (3,)),
  xnn.ReLU(),
  xnn.Dropout(0.3, mode='test'),
  xnn.Mean(axis=-1),
)
for (i, example) in enumerate(train_set):
  # Run the current example through the training-mode forward.
  params = update(params, train_forward, states, example)
  if i % 1000 == 0:
    all_metrics = [evaluate(params, test_forward, states, batch) for batch in eval_set]
    avg_metrics = average(all_metrics)
    log(avg_metrics, step=i)

A more intelligent user might write:

import xnn
def MyModule(mode='train'):
  return xnn.Sequential(
    xnn.Embed(1024, 32),
    xnn.Conv(32, 64, (3,)),
    xnn.ReLU(),
    xnn.Dropout(0.3, mode=mode),
    xnn.Conv(64, 32, (3,)),
    xnn.ReLU(),
    xnn.Dropout(0.3, mode=mode),
    xnn.Mean(axis=-1),
  )
train_forward, params, states = MyModule(mode='train')
test_forward, _, _ = MyModule(mode='test')

for (i, example) in enumerate(train_set):
  params = update(params, train_forward, states, example)
  if i % 1000 == 0:
    all_metrics = [evaluate(params, test_forward, states, batch) for batch in eval_set]
    avg_metrics = average(all_metrics)
    log(avg_metrics, step=i)

But this still: 1. creates the module twice, 2. requires tracking both train_forward and test_forward, and 3. forces the user to handle the bookkeeping themselves.

Flax and Haiku handle this by passing a 'deterministic' or 'is_training' keyword argument to the forward call. PyTorch, Equinox, and Treex handle it by calling module.train() and module.eval() (one can imagine an immutable version where eval_module = module.eval() ... train_module = eval_module.train(), or eval_module = set_mode(module, 'eval') ... train_module = set_mode(eval_module, 'train')).
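Roughly, the two existing styles look like this (simplified sketches, not the libraries' exact signatures):

# Call-time style (Flax/Haiku-like): mode is an argument of the forward call.
y = forward(params, x, is_training=False)

# Module-state style (PyTorch): mode is toggled on the module in place.
model.eval()     # sets module.training = False, recursively
y = model(x)
model.train()    # back to training mode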

I think it would make more sense to move the train vs. test mode argument to the forward function or into the states, or to provide a functional way to set the mode (e.g. eval_module = set_mode(module, 'eval')), particularly since several modules behave differently at train and test time (dropout, batch norm, layerdrop, etc.).
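For illustration, a rough sketch of the states-based option from the user's side, assuming xnn's forward(params, inputs, states) convention; set_mode here is hypothetical, not an existing xnn function:

import xnn
forward, params, states = xnn.Sequential(
  xnn.Embed(1024, 32),
  xnn.Dropout(0.3),
  xnn.Mean(axis=-1),
)
# One module serves both phases; only the states differ.
train_states = xnn.set_mode(states, 'train')  # hypothetical API
test_states = xnn.set_mode(states, 'test')    # hypothetical API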

Thoughts?

zhangxiangxiao commented 2 years ago

I think it's an awesome idea! I'm leaning towards having a 'mode' entry in the state dictionary for that. It will have to work well with:

  1. Container module constructors such as xnn.Sequential, xnn.Parallel etc.
  2. Vectorization using xnn.vectorize.

For 1, I'm thinking that we should clear the 'mode' state for all modules when constructing the container, and only change the top-level 'mode' entry when calling set_mode. In the container forward, we then insert the 'mode' entry into all the constructed states for the submodules.
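A rough sketch of that flow, with hypothetical container internals (the container's states hold a top-level 'mode' plus a list of per-submodule states under 'sub'):

def set_mode(states, mode):
  # Pure update: only the top-level 'mode' entry changes; submodule
  # states are refreshed when the container forward runs.
  return {**states, 'mode': mode}

def sequential_forward(forwards, params, inputs, states):
  outputs, new_sub = inputs, []
  for fwd, p, s in zip(forwards, params, states['sub']):
    # Push the container's mode down into each submodule's states,
    # overwriting any stale 'mode' the submodule might hold.
    outputs, s = fwd(p, outputs, {**s, 'mode': states['mode']})
    new_sub.append(s)
  return outputs, {**states, 'sub': new_sub}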

For 2, we need to let xnn.vectorize ignore the 'mode' entry so that it does not vectorize string states.
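One way that could work, sketched with jax.vmap directly (xnn.vectorize's actual batching rules may differ; this assumes shared params and batch-axis-0 inputs and states):

import jax

def vectorize(forward):
  def batched_forward(params, inputs, states):
    # Split out the non-array 'mode' entry and close over it, so
    # jax.vmap never sees a string leaf in the states pytree.
    mode = states['mode']
    rest = {k: v for k, v in states.items() if k != 'mode'}
    def body(p, x, s):
      out, new_s = forward(p, x, {**s, 'mode': mode})
      # Drop 'mode' again before leaving vmap: vmapped outputs
      # must be array-like.
      return out, {k: v for k, v in new_s.items() if k != 'mode'}
    outputs, new_rest = jax.vmap(body, in_axes=(None, 0, 0))(params, inputs, rest)
    return outputs, {**new_rest, 'mode': mode}
  return batched_forward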

Want to prototype this and send a pull request?

zhangxiangxiao commented 2 years ago

A 3rd requirement:

  3. It has to work with jax.jit.

This means that we cannot use Python strings as values for 'mode'. Perhaps an integer enum constant would do.
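A sketch of how a hypothetical dropout forward could branch on an integer mode while staying jittable:

import jax
import jax.numpy as jnp

# Hypothetical integer constants stored under states['mode'].
TRAIN, TEST = 0, 1

def dropout_forward(params, inputs, states, rate=0.3):
  rng, mode = states['rng'], states['mode']
  rng, subkey = jax.random.split(rng)
  keep = jax.random.bernoulli(subkey, 1 - rate, inputs.shape)
  dropped = jnp.where(keep, inputs / (1 - rate), 0)
  # With an integer mode, select between branches with jnp.where
  # instead of a Python `if`, so jax.jit can trace the function.
  outputs = jnp.where(mode == TRAIN, dropped, inputs)
  return outputs, {'rng': rng, 'mode': mode}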