ContinualAI / avalanche

Avalanche: an End-to-End Library for Continual Learning based on PyTorch.
http://avalanche.continualai.org
MIT License

Inputs to the `forward`- and `criterion`-functions in the BaseStrategy-object are limited #596

Closed GMvandeVen closed 3 years ago

GMvandeVen commented 3 years ago

Hi, I was looking into implementing a VAE-based generative classifier (as described here: https://arxiv.org/abs/2104.10093) in Avalanche as a new baseline, but I ran into the issue that in the BaseStrategy object the inputs to self.criterion are restricted to self.logits and self.mb_y (see this line: https://github.com/ContinualAI/avalanche/blob/c7aea348c6558efd5de471bf659ca23ef5e5de94/avalanche/training/strategies/base_strategy.py#L411) and the inputs to self.forward are restricted to self.mb_x and self.mb_task_id (see here: https://github.com/ContinualAI/avalanche/blob/c7aea348c6558efd5de471bf659ca23ef5e5de94/avalanche/training/strategies/base_strategy.py#L555). To implement the above generative classifier strategy, self.criterion would also need self.mb_x as input and self.forward would also need self.mb_y as input. In general, though, I think it would be good to make all data unpacked by _unpack_minibatch available as inputs to both of these functions. (On a related note, I would rename self.logits, i.e. the output of self.forward, to self.model_output or something similar.)

Also for implementing generative replay strategies, it would be helpful if the inputs to self.forward and self.criterion could be more flexible.

I appreciate that this suggestion would break some things, but I think there are ways around that, although it makes things a bit more involved. The change to self.criterion would break the direct use of nn.CrossEntropyLoss as a loss function, but a way around this might be to introduce a new base class LossFunction (or similar), which for the current implementations could simply wrap nn.CrossEntropyLoss. The change to self.forward would break current models that expect a fixed number of inputs, but a way around this could be to introduce a new base class AvalancheModel (or similar), which would require models to have forward functions that accept more flexible inputs. I think both of these would have other advantages as well. Of course, there may be other ways around these issues that are better or easier, and there may be other things that would break that I haven't thought of.
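To make the required flexibility concrete, below is a rough sketch of the forward pass and loss of a class-conditional VAE (the helper names are hypothetical and not part of Avalanche or the paper's code; the reconstruction term assumes inputs scaled to [0, 1]). It illustrates why the forward pass needs self.mb_y in addition to self.mb_x, and why the loss needs self.mb_x in addition to the model output:

    import torch
    import torch.nn.functional as F

    def vae_forward(model, x, y):
        # a class-conditional VAE encodes/decodes x conditioned on the label y,
        # so the forward pass needs both self.mb_x and self.mb_y
        return model(x, y)  # e.g. returns (reconstruction, mu, logvar)

    def vae_loss(output, x):
        # the VAE loss compares the reconstruction against the original input,
        # so the criterion needs self.mb_x, not just the output and self.mb_y
        recon, mu, logvar = output
        rec = F.binary_cross_entropy(recon, x, reduction="sum")  # assumes x in [0, 1]
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return rec + kld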

Apologies for the lengthy issue!! I hope it is clear/makes sense, but please let me know if not.

AntonioCarta commented 3 years ago

Custom Forward

Right now forward is a method of the BaseStrategy:

    def forward(self):
        if isinstance(self.model, MultiTaskModule):
            return self.model.forward(self.mb_x, self.mb_task_id)
        else:  # no task labels
            return self.model.forward(self.mb_x)

This makes it easy to customize the method for your own needs by defining a subclass:

    class MyStrategy(BaseStrategy):
        def forward(self):
            return self.model(self.mb_x, self.mb_y)

In general, I think it's impossible to guess in advance what inputs your model will need. Avalanche tries to provide a reasonable default, and allows you to customize it.

> In general, though, I think it would be good to make all data unpacked by _unpack_minibatch available as inputs to both of these functions.

The full minibatch is available through self.mbatch. We don't give its elements individual names because we can't know what they are. Usually, I just add a property in my custom strategy:

    # this is a property of a BaseStrategy subclass
    @property
    def class_mask(self):
        # e.g. the dataset returns 5-element minibatches and the class mask is the 4th element
        assert len(self.mbatch) == 5
        return self.mbatch[3]

Custom Loss

For the loss function, on the other hand, we don't have an explicit method to override, and this is something we need to address. Right now, Avalanche plugins add regularization losses inside before_backward (i.e. just before the backpropagation). Example from LwF:

    def before_backward(self, strategy, **kwargs):
        """
        Add distillation loss
        """
        alpha = self.alpha[strategy.training_exp_counter] \
            if isinstance(self.alpha, (list, tuple)) else self.alpha
        penalty = self.penalty(strategy.logits, strategy.mb_x, alpha)
        strategy.loss += penalty

We could provide a criterion method inside the BaseStrategy to make it easier to define custom losses:

    class MyStrategy(BaseStrategy):
        def forward(self):
            # if you need it, here you can also compute additional outputs
            self.something_else = self.generator_model(self.mb_x, self.mb_y, self.aux)
            return self.model(self.mb_x, self.mb_y)

        def criterion(self):
            if self.training:
                return cross_entropy(...) + kd(...)
            else:
                return cross_entropy(...)
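For reference, a minimal sketch of what the default criterion method could look like (assuming the strategy stores the user-supplied PyTorch loss in an attribute, here hypothetically called self._criterion), so that standard losses such as nn.CrossEntropyLoss keep working without any change:

    # sketch of a possible default inside BaseStrategy; `self._criterion` is a
    # hypothetical attribute holding a plain PyTorch loss, e.g. nn.CrossEntropyLoss
    def criterion(self):
        return self._criterion(self.logits, self.mb_y)

The training loop could then set self.loss = self.criterion() once per iteration, and strategies like the one above would simply override criterion.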

@GMvandeVen what do you think of this approach? It seems easier to me because we can support PyTorch losses automatically while leaving users the freedom to define custom losses by subclassing and overriding the corresponding methods.

In summary:

GMvandeVen commented 3 years ago

Thanks @AntonioCarta for your quick response, that's very helpful and helps me understand the organization of the training loop better! Yes, I think the approach and modifications you suggest make sense and should work, certainly for most purposes.

Perhaps the only downside of this approach is that, if I understand it correctly, it is only possible to change the forward or criterion function (or its inputs) by defining a new strategy and overriding the methods. Could there be an advantage to also making it possible to change these functions by defining a new plug-in?