Open · ryanai3 opened this issue 2 years ago
I think it's an awesome idea! I'm leaning towards having a 'mode' entry in the state dictionary for that. It will have to work well with:

1. Module containers.
2. `xnn.vectorize`.
For 1, I'm thinking that we should clear the 'mode' state for all modules when constructing the container, and only change the top-level 'mode' state entry when calling set_mode. In the container's forward, we insert the 'mode' entry into all the constructed states for submodules.
For 2, we need to let xnn.vectorize ignore the 'mode' entry so that it does not try to vectorize string-valued states.
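A minimal sketch of what requirement 1 could look like, treating 'mode' as just another state-dict entry. All names here (`set_mode`, `container_forward`) are illustrative, not the actual xnn API:

```python
# Hypothetical sketch: 'mode' lives in the state dict and is propagated
# from the container's state into each submodule state at forward time.

def set_mode(state, mode):
    """Return a copy of a state dict with only the top-level 'mode' changed."""
    new_state = dict(state)
    new_state['mode'] = mode
    return new_state

def container_forward(state, submodule_states):
    """Insert the container's 'mode' entry into every submodule state."""
    return [dict(sub, mode=state.get('mode')) for sub in submodule_states]

# Submodule 'mode' entries start cleared; only the top level is ever set.
state = {'mode': None, 'w': 1.0}
train_state = set_mode(state, 'train')
subs = container_forward(train_state, [{'w': 2.0}, {'w': 3.0}])
```

Because `set_mode` copies rather than mutates, the original state is untouched, which fits the immutable/functional style discussed below.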
Want to prototype this and send a pull request?
A 3rd requirement:
This means that we cannot use Python strings as values for 'mode'. Perhaps an integer enum constant would do.
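For example, an integer enum (illustrative only) keeps the mode entry a plain numeric value that vectorization and tracing machinery can handle:

```python
from enum import IntEnum

class Mode(IntEnum):
    TRAIN = 0
    EVAL = 1

# Integer-valued, so it can sit in a state dict alongside numeric arrays
state = {'mode': int(Mode.TRAIN)}
```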
Currently, train vs. test mode is defined at module construction. This means that if a user wants to run in both modes (e.g. evaluating accuracy on validation examples every 1000 training steps), the user needs to construct the module twice (or write code that implicitly does so). Example version A:
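A sketch of what version A might look like, using a toy stand-in Dropout whose mode is fixed at construction (illustrative names, not the actual library API):

```python
class Dropout:
    """Toy stand-in: mode is fixed when the module is constructed."""
    def __init__(self, rate, mode):
        self.rate = rate
        self.mode = mode  # 'train' or 'test', baked in forever

    def forward(self, xs):
        if self.mode == 'train':
            # stand-in for random masking: deterministic scaling
            return [x * (1 - self.rate) for x in xs]
        return list(xs)  # identity at test time

# The user must build the same layer twice, once per mode:
train_dropout = Dropout(rate=0.5, mode='train')
test_dropout = Dropout(rate=0.5, mode='test')
```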
A more intelligent user might write:
But this still: 1. creates the module twice, 2. requires tracking both train_forward and test_forward, and 3. forces the user to handle all of this themselves.
Flax & Haiku handle this by passing a `deterministic` or `is_training` keyword argument to the forward call. PyTorch, Equinox, and Treex handle it by calling `module.train()` and `module.eval()`. (One can imagine an immutable version where `eval_module = module.eval()` and `train_module = eval_module.train()`, or `eval_module = set_mode(module, 'eval')` and `train_module = set_mode(eval_module, 'train')`.)

I think it would make more sense to move the train vs. test mode argument to the forward function or into the states, or to provide a functional way to set the train mode (e.g. `eval_module = set_mode(module, 'eval')`), particularly as there are several modules that behave differently at train and test time (dropout, batch norm, LayerDrop, etc.).

Thoughts?