AntonioCarta opened 3 years ago
Disable callbacks
A big limitation of the plugin system is that the set of callbacks is fixed. Some strategies may skip some stages and break plugins as a result. The only example up to now is DeepSLDA, which does not use backprop, but there may be others in the future. It would be better if we could somehow add and remove callbacks for custom strategies. If a plugin implements a callback that the strategy does not support, it should raise an error.
Right now, I'm not sure how we can implement such a system. I would like to avoid complex solutions based on metaprogramming/reflection as they would complicate plugins too much.
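To make the failure mode concrete, here is a minimal, self-contained sketch (all names are illustrative, not the real Avalanche API): the strategy drives the callbacks, so a strategy that skips the backward pass silently skips `before_backward` and any plugin logic attached to it.

```python
class TinyPlugin:
    def before_backward(self, strategy):
        print("regularization added")  # never printed for the SLDA-like case


class BackpropStrategy:
    def train(self, plugins):
        for p in plugins:
            p.before_backward(self)  # hook fires around loss.backward()


class SLDALikeStrategy:
    def train(self, plugins):
        pass  # closed-form update: no backward pass, hook never fires


SLDALikeStrategy().train([TinyPlugin()])  # the plugin silently does nothing
```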
For strategies inheriting from `BaseStrategy` it should be doable. For example, `BaseStrategy` could have a `ClassVar` `DISABLED_CALLBACKS`, and it would check in the constructor that the plugins do not override the specified callbacks.
Small POC:
```python
import logging
from typing import ClassVar, Sequence


class BaseStrategy:
    # Callbacks this strategy never triggers; subclasses override this.
    DISABLED_CALLBACKS: ClassVar[Sequence[str]] = ()

    def __init__(self, ...):
        ...
        self._alert_disabled_plugins_callbacks()

    def _alert_disabled_plugins_callbacks(self):
        for disabled_callback_name in self.DISABLED_CALLBACKS:
            for plugin in self.plugins:
                plugin_callback = getattr(plugin, disabled_callback_name)
                # Class that actually defines the callback: if it is not the
                # StrategyPlugin base class, the plugin overrides the hook.
                callback_class_origin = plugin_callback.__qualname__.split('.')[0]
                if (
                    not isinstance(plugin, EvaluationPlugin)
                    and callback_class_origin != "StrategyPlugin"
                ):
                    logging.warning(
                        f"{plugin} seems to use the callback "
                        f"{disabled_callback_name} which is disabled by {self}"
                    )


class NoBeforeBackwardStrategy(BaseStrategy):
    # Example: a strategy that never runs a backward pass.
    DISABLED_CALLBACKS: ClassVar[Sequence[str]] = ('before_backward',)
```
Now if we try to use the new strategy with an LwF plugin that expects to use the callback, we get a warning:

```python
NoBeforeBackwardStrategy(
    model,
    Adam(model.parameters()),
    plugins=[LwFPlugin(), ReplayPlugin()]
)
```

```
WARNING:root:<avalanche.training.plugins.lwf.LwFPlugin object at 0x133a01bb0> seems to use the callback before_backward which is disabled by <__main__.NoBeforeBackwardStrategy object at 0x133a019a0>
```
We wouldn't get warnings for metrics (which should be acceptable), and we might get some unwanted warnings when a generic plugin class overrides more methods than the final child plugin actually uses.
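For instance, a hierarchy like this (hypothetical classes) would trigger the warning even though the concrete plugin never relies on the hook:

```python
class GenericLoggingPlugin(StrategyPlugin):
    # Overridden "just in case": __qualname__ now starts with
    # "GenericLoggingPlugin", not "StrategyPlugin".
    def before_backward(self, strategy, **kwargs):
        pass


class ChildPlugin(GenericLoggingPlugin):
    # Does not use before_backward, but inherits the override,
    # so the check above reports a spurious warning.
    pass
```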
What do you think?
@Mathieu4141 I like your proposal. It should be possible to check the metrics as well with the same mechanism, by checking `self.evaluator.metrics`.
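A sketch of that extension, assuming the evaluator exposes its metrics as `self.evaluator.metrics` and that metric callbacks are defined on a `PluginMetric` base class (both assumptions, not the final implementation):

```python
def _alert_disabled_metric_callbacks(self):
    for disabled_callback_name in self.DISABLED_CALLBACKS:
        for metric in self.evaluator.metrics:
            callback = getattr(metric, disabled_callback_name, None)
            if callback is None:
                continue
            # Same qualname trick as for plugins: warn only if the metric
            # class itself overrides the disabled callback.
            origin = callback.__qualname__.split('.')[0]
            if origin != "PluginMetric":
                logging.warning(
                    f"{metric} seems to use the callback "
                    f"{disabled_callback_name} which is disabled by {self}"
                )
```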
Let me know if you want to implement this solution.
Yes I'd like to implement that, I'll open a PR during the week 🙂
Plugins are complex objects due to their interaction with the strategy state and the fact that they need to work with any strategy. In general, plugins that only read the strategy state are safe to implement, while plugins that change the state require more care. Additive changes (adding a regularization function, concatenating samples to the training data) are safe to use with any strategy.
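As an illustration of an additive change, here is a hedged sketch (the plugin itself is hypothetical; `StrategyPlugin`, `before_backward`, and `strategy.loss` are the real extension points):

```python
from avalanche.training.plugins import StrategyPlugin


class L2PenaltyPlugin(StrategyPlugin):
    """Additive change: only adds a term to the loss, so it composes
    safely with any strategy that runs a backward pass."""

    def __init__(self, lam=1e-4):
        super().__init__()
        self.lam = lam

    def before_backward(self, strategy, **kwargs):
        penalty = sum(p.pow(2).sum() for p in strategy.model.parameters())
        strategy.loss = strategy.loss + self.lam * penalty
```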
Looking at the three large strategy categories: changes to `adapted_dataset` are safe, while changing the `DataLoader` is not (see the sketch below). I think plugins are quite flexible and most users are using them successfully. I don't see any easy way to restrict their usage without severely limiting their usefulness. I'm open to suggestions here.
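For example, an additive change to `adapted_dataset` might look like this (hypothetical plugin; `AvalancheConcatDataset` and the `before_training_exp` hook are the real utilities):

```python
from avalanche.benchmarks.utils import AvalancheConcatDataset
from avalanche.training.plugins import StrategyPlugin


class ConcatBufferPlugin(StrategyPlugin):
    def __init__(self, buffer):
        super().__init__()
        self.buffer = buffer  # previously stored samples

    def before_training_exp(self, strategy, **kwargs):
        # Additive and safe: extend the data the strategy will load,
        # while the strategy keeps control of how the DataLoader is built.
        strategy.adapted_dataset = AvalancheConcatDataset(
            [strategy.adapted_dataset, self.buffer]
        )
```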
As a general rule: `adapted_dataset` and `dataloader` are changed by the plugin. Unfortunately, I don't see any easy way to catch misuses at runtime (e.g. a user who tries to override the dataloader when using replay plugins). In any case, most problems are probably obvious or highlighted by the documentation. We can keep everything in `training.plugins` but organize the documentation better, like `torch.nn` does. We also need to make some changes to make them easier to understand.
Disable callbacks
This is the limitation discussed at the top of this issue: the set of callbacks is fixed, and strategies that skip stages (such as DeepSLDA, which does not use backprop) can silently break plugins.
Interaction between plugins and train arguments
We have the possibility to pass arguments to `train`. Right now, this is only used to change the dataloader's arguments (`batch_size`, `num_workers`, ...). PROBLEM: if a plugin overrides a value, the arguments are ignored! For example:
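(A sketch of such a call, assuming Avalanche's `Naive` strategy and `ReplayPlugin`; argument values are illustrative.)

```python
from torch.optim import SGD
from avalanche.training.plugins import ReplayPlugin
from avalanche.training.strategies import Naive

# model, criterion, and experience are defined elsewhere
strategy = Naive(
    model, SGD(model.parameters(), lr=0.01), criterion,
    plugins=[ReplayPlugin(mem_size=200)],
)
# the user asks for a specific batch size here...
strategy.train(experience, num_workers=4, batch_size=128)
# ...but ReplayPlugin builds its own dataloader
```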
In this example, `batch_size` is ignored because replay overrides the dataloader. This is very weird and unexpected for the user, and it will probably go unnoticed. I see two solutions:
If you have a better solution, I'm happy to hear it.