aimagelab / mammoth

An Extendible (General) Continual Learning Framework based on Pytorch - official codebase of Dark Experience for General Continual Learning
MIT License

How to Use on_EWC/ DER++ to Handle Regression Tasks #48

Open JieDengsc opened 3 weeks ago

JieDengsc commented 3 weeks ago

Specifically, I don't know how to modify the code, because my task is now regression.

loribonna commented 3 weeks ago

Hi @JieDengsc, Mammoth is not really designed to handle regression, but depending on your use case it may be easy to adapt.

Because the task is regression, you probably want to define a domain-il task: without class labels you wouldn't know how to split the data into separate tasks. Taking the "perm-mnist" dataset (in datasets/perm_mnist.py) as an example, you could create a new file datasets/<your_dataset>.py and in it define a class that inherits from ContinualDataset and defines:

Besides this, you will need to modify

If you don't want a "domain-il" setting and want to split the data according to some other policy, I still suggest defining a "domain-il" dataset and splitting the data in get_data_loaders.
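A minimal sketch of the kind of splitting policy that could live inside get_data_loaders, assuming a regression dataset stored as (x, y) pairs. The helper split_into_tasks and the toy data are hypothetical and not part of Mammoth; the sketch only shows the partitioning step, not how the loaders are built.

```python
import random


def split_into_tasks(samples, n_tasks, key):
    """Partition samples into n_tasks contiguous groups ordered by `key`.

    Hypothetical helper (not a Mammoth API): sorts the samples by a
    user-chosen criterion and cuts them into near-equal chunks, so each
    chunk can be served as one task of a domain-il style sequence.
    """
    ordered = sorted(samples, key=key)
    size, rem = divmod(len(ordered), n_tasks)
    tasks, start = [], 0
    for t in range(n_tasks):
        # The first `rem` tasks take one extra sample each.
        end = start + size + (1 if t < rem else 0)
        tasks.append(ordered[start:end])
        start = end
    return tasks


# Toy regression data (assumption for illustration): y = 2x + noise.
rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(10)]
data = [(x, 2.0 * x + rng.gauss(0.0, 0.1)) for x in xs]

# Policy used here: split by the input value, so each task covers one input range.
tasks = split_into_tasks(data, n_tasks=3, key=lambda s: s[0])
```

Any other policy (for example, splitting by a metadata field or by target range) only changes the key function.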

We plan to introduce some regression tasks in the future. Let me know if yours is publicly available so that we can take a look at it.

JieDengsc commented 3 weeks ago

Hi @loribonna, thanks for your reply and suggestion. I will try it.

In addition, for ewc_on.py, when calculating the Fisher matrix, why do you multiply by exp_cond_prob in the code? The line is: fish += exp_cond_prob * self.net.get_grads() ** 2. According to the paper, you only need to sum the squares of the gradients and take the average at the end.

Please let me know if I've misunderstood anything. Thanks a lot.

loribonna commented 3 weeks ago

The question is a bit of a rabbit hole and I'm not an expert on this, but the reason is that the Fisher information matrix is computed as the expectation, over the model's predictions, of the squared gradients. You therefore need to weight them by p(y|x), which is why we take the exp of the loss.
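To make the difference concrete, here is a small self-contained sketch, not the code from ewc_on.py. It assumes a toy linear softmax classifier with logits z = W x, and compares the diagonal Fisher taken as an expectation over the model's own predictions with the unweighted "empirical" version that squares the gradient at the observed label. All function names are hypothetical.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def grad_log_prob(W, x, y):
    """Return (gradient of log p(y|x) w.r.t. W, flattened row-major; p(.|x))."""
    p = softmax([sum(w * xi for w, xi in zip(row, x)) for row in W])
    # d log p(y|x) / d W[c][d] = (1[c == y] - p[c]) * x[d]
    g = [((1.0 if c == y else 0.0) - p[c]) * xi
         for c in range(len(W)) for xi in x]
    return g, p


def fisher_diag(W, inputs):
    """Diagonal Fisher: mean over x of sum_y p(y|x) * (grad log p(y|x))^2."""
    n_params = len(W) * len(W[0])
    fish = [0.0] * n_params
    for x in inputs:
        for y in range(len(W)):
            g, p = grad_log_prob(W, x, y)
            for i in range(n_params):
                fish[i] += p[y] * g[i] ** 2  # weight by the model's p(y|x)
    return [f / len(inputs) for f in fish]


def empirical_fisher_diag(W, data):
    """Unweighted variant: mean of squared gradients at the observed labels."""
    n_params = len(W) * len(W[0])
    fish = [0.0] * n_params
    for x, y in data:
        g, _ = grad_log_prob(W, x, y)
        for i in range(n_params):
            fish[i] += g[i] ** 2
    return [f / len(data) for f in fish]


W = [[1.0], [0.0]]  # two classes, one input feature (toy values)
exact = fisher_diag(W, [[1.0]])
empirical = empirical_fisher_diag(W, [([1.0], 1)])
```

For this toy model the expectation version reduces to p_c (1 - p_c) x^2 for each weight, while the unweighted version depends on which label happened to be observed, so the two generally disagree.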

I suggest you check out this paper and this discussion for more info.

Edit: in your regression scenario, you could reuse the same code from EWC, but I don't think the math would check out.