ContinualAI / avalanche

Avalanche: an End-to-End Library for Continual Learning based on PyTorch.
http://avalanche.continualai.org
MIT License

[BACKWARD INCOMPATIBLE] Plugins and interactions with strategy and arguments #659

Open AntonioCarta opened 3 years ago

AntonioCarta commented 3 years ago

Plugins are complex objects due to their interaction with the strategy state and the fact that they need to work with any strategy. In general, plugins that only read the strategy state are safe to implement, while plugins that change the state require more care. Additive changes (adding a regularization function, concatenating samples to the training data) are safe to use with any strategy.

Looking at the three large strategy categories:

I think plugins are quite flexible and most users are using them successfully. I don't see any easy way to restrict their usage without severely limiting their usefulness. I'm open to suggestions here.

As a general rule:

We also need to make some changes to make them easier to understand.

Disable callbacks

A big limitation of the plugin system is that the set of callbacks is fixed. Some strategies may skip some stages and break plugins as a result. The only example up to now is DeepSLDA, which does not use backprop, but there may be others in the future. It would be better if we could somehow add and remove callbacks for custom strategies. If a plugin implements a callback that the strategy does not support, it should raise an error.

Right now, I'm not sure how we can implement such a system. I would like to avoid complex solutions based on metaprogramming/reflection as they would complicate plugins too much.

Interaction between plugins and train arguments

We can pass arguments to train. Right now, this is only used to change the dataloader's arguments (batch_size, num_workers...). PROBLEM: if a plugin overrides a value, the arguments are ignored!

For example:

strat = Naive(..., plugins=[replay])
strat.train(exp, batch_size=128)

In this example, batch_size is ignored because replay overrides the dataloader. This is very weird and unexpected for the user, and it will probably go unnoticed. I see two solutions:

If you have a better solution I'm happy to hear it.
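To make the problem concrete, here is a minimal self-contained sketch of the failure mode. It does not use Avalanche: ToyStrategy, ToyReplayPlugin and the forward_kwargs flag are hypothetical stand-ins, and the dataloader is modelled as a plain dict. Forwarding the arguments is shown only to contrast the two behaviours; it is not necessarily one of the solutions mentioned above.

```python
class ToyStrategy:
    """Stand-in for a strategy; NOT Avalanche's real BaseStrategy."""

    def __init__(self, plugins=()):
        self.plugins = list(plugins)
        self.dataloader = None

    def train(self, data, **dataloader_kwargs):
        # The strategy builds its dataloader from the train() arguments...
        self.dataloader = {"data": list(data), **dataloader_kwargs}
        # ...then every plugin gets a chance to react (hypothetical hook).
        for plugin in self.plugins:
            plugin.before_training_exp(self, **dataloader_kwargs)
        return self.dataloader


class ToyReplayPlugin:
    """Hypothetical replay plugin that replaces the strategy's dataloader."""

    def __init__(self, buffer, forward_kwargs=False):
        self.buffer = list(buffer)
        self.forward_kwargs = forward_kwargs

    def before_training_exp(self, strategy, **kwargs):
        data = strategy.dataloader["data"] + self.buffer
        # Without forwarding, the user's arguments are silently dropped.
        extra = kwargs if self.forward_kwargs else {}
        strategy.dataloader = {"data": data, **extra}


dropped = ToyStrategy([ToyReplayPlugin([9])]).train([1, 2], batch_size=128)
kept = ToyStrategy([ToyReplayPlugin([9], forward_kwargs=True)]).train(
    [1, 2], batch_size=128
)
print("batch_size" in dropped)  # False
print(kept["batch_size"])       # 128
```

In the first call nothing fails and nothing is logged, which is why the lost batch_size would probably go unnoticed.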

Mathieu4141 commented 3 years ago

Disable callbacks

A big limitation of the plugin system is that the set of callbacks is fixed. Some strategies may skip some stages and break plugins as a result. The only example up to now is DeepSLDA, which does not use backprop, but there may be others in the future. It would be better if we could somehow add and remove callbacks for custom strategies. If a plugin implements a callback that the strategy does not support, it should raise an error.

Right now, I'm not sure how we can implement such a system. I would like to avoid complex solutions based on metaprogramming/reflection as they would complicate plugins too much.

For strategies inheriting from BaseStrategy it should be doable.

For example, BaseStrategy could have a ClassVar DISABLED_CALLBACKS and check in its constructor that the plugins do not override the specified callbacks.

Small POC:


import logging
from typing import ClassVar, Sequence

# Assumed import path for the two Avalanche classes used below.
from avalanche.training.plugins import EvaluationPlugin, StrategyPlugin


class BaseStrategy:
    # Names of the callbacks that this strategy never calls.
    DISABLED_CALLBACKS: ClassVar[Sequence[str]] = ()

    def __init__(self, ...):
        ...
        self._alert_disabled_plugins_callbacks()

    def _alert_disabled_plugins_callbacks(self):
        """Warn about plugins that override a callback disabled by this strategy."""
        for disabled_callback_name in self.DISABLED_CALLBACKS:
            for plugin in self.plugins:
                plugin_callback = getattr(plugin, disabled_callback_name)
                # The first component of __qualname__ is the class defining the method.
                callback_class_origin = plugin_callback.__qualname__.split('.')[0]
                # Skip the evaluation plugin and callbacks inherited unchanged
                # from the StrategyPlugin base class.
                if (
                    not isinstance(plugin, EvaluationPlugin)
                    and callback_class_origin != "StrategyPlugin"
                ):
                    logging.warning(
                        f"{plugin} seems to use the callback "
                        f"{disabled_callback_name} which is disabled by {self}"
                    )


class NoBeforeBackwardStrategy(BaseStrategy):
    DISABLED_CALLBACKS: ClassVar[Sequence[str]] = ('before_backward',)

Now, if we try to use the new strategy with an LwF plugin that expects to use the callback, we get a warning:

NoBeforeBackwardStrategy(
    model,
    Adam(model.parameters()),
    plugins=[LwFPlugin(), ReplayPlugin()]
)

WARNING:root:<avalanche.training.plugins.lwf.LwFPlugin object at 0x133a01bb0> seems to use the callback before_backward which is disabled by <__main__.NoBeforeBackwardStrategy object at 0x133a019a0>

We wouldn't get warnings for metrics (which should be acceptable), and we might get some unwanted warnings if a generic plugin class overrides more methods than the final child plugin actually uses.

What do you think?

AntonioCarta commented 3 years ago

@Mathieu4141 I like your proposal. It should also be possible to check the metrics with the same mechanism by going through self.evaluator.metrics.
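A rough sketch of what extending the check to metrics could look like, using only stand-in classes. Callbacks, ToyMetric, ToyEvaluator, ToyStrategy and the overridden helper are hypothetical names, and the sketch assumes that metrics expose the same callback methods as plugins. It also detects overrides by comparing function identity against the base class instead of parsing __qualname__, which is a swapped-in technique, not the one from the POC.

```python
import logging
from typing import ClassVar, Sequence


class Callbacks:
    """Stand-in for the shared callback interface (hypothetical)."""

    def before_backward(self, strategy, **kwargs):
        pass

    def after_backward(self, strategy, **kwargs):
        pass


class PlainPlugin(Callbacks):
    """Inherits every callback unchanged."""


class ToyMetric(Callbacks):
    def before_backward(self, strategy, **kwargs):
        pass  # overriding is what the check looks for


class ToyEvaluator:
    def __init__(self, metrics):
        self.metrics = list(metrics)


def overridden(obj, name):
    """Return True if type(obj) overrides callback `name` from Callbacks."""
    return getattr(type(obj), name) is not getattr(Callbacks, name)


class ToyStrategy:
    DISABLED_CALLBACKS: ClassVar[Sequence[str]] = ("before_backward",)

    def __init__(self, plugins, evaluator):
        self.plugins = list(plugins)
        self.evaluator = evaluator
        # Plugins and metrics go through exactly the same check.
        candidates = [*self.plugins, *self.evaluator.metrics]
        self.offenders = [
            (obj, name)
            for name in self.DISABLED_CALLBACKS
            for obj in candidates
            if overridden(obj, name)
        ]
        for obj, name in self.offenders:
            logging.warning(
                "%s uses the callback %s, which is disabled by %s",
                type(obj).__name__, name, type(self).__name__,
            )


strategy = ToyStrategy([PlainPlugin()], ToyEvaluator([ToyMetric()]))
```

Here only ToyMetric is reported, since PlainPlugin inherits before_backward unchanged. Replacing logging.warning with a raised exception would give the stricter behaviour described in the first comment.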

Let me know if you want to implement this solution.

Mathieu4141 commented 3 years ago

Yes I'd like to implement that, I'll open a PR during the week 🙂