TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

only one model in FusionClassifier ensemble is learning anything #94

Closed. kiranchari closed this issue 3 years ago.

kiranchari commented 3 years ago

I am using a FusionClassifier of 3 estimators with a ResNet backbone. After training the ensemble from scratch on CIFAR10, I evaluated each estimator in the ensemble individually and found that only one of them has learned anything; the others perform no better than random. So it appears the ensemble did not train properly for some reason.

Have you come across this issue? Any tips for debugging it?
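A minimal sketch of the per-estimator check described above, in PyTorch. It assumes the trained base estimators are available as a list of `nn.Module`s that map a batch of inputs to class logits (torchensemble appears to keep them in the ensemble's `estimators_` attribute, but verify this for your version); `per_estimator_accuracy` is a hypothetical helper, not a torchensemble function.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def per_estimator_accuracy(estimators, loader, device="cpu"):
    """Return the accuracy of each base estimator evaluated on its own.

    `estimators` is assumed to be an iterable of nn.Module objects that
    output class logits; this is a hypothetical helper for illustration.
    """
    accuracies = []
    for model in estimators:
        model.eval()  # disable dropout, use running BatchNorm statistics
        correct, total = 0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)  # predicted class per sample
            correct += (pred == y).sum().item()
            total += y.numel()
        accuracies.append(correct / total)
    return accuracies
```

On CIFAR10, an estimator that has learned nothing should score close to 0.10, the chance level for 10 classes.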

xuyxu commented 3 years ago

My suggestion is to give VotingClassifier a try; FusionClassifier is not well suited to base models like ResNet ;-)

kiranchari commented 3 years ago

I actually need to train the models jointly because of an additional regularisation I am applying. Which base models is FusionClassifier suitable for?

xuyxu commented 3 years ago

What kind of regularization are you going to add to the model? Is it hard-coded into the forward pass of the base estimator, or is it simply another term in the loss function?

kiranchari commented 3 years ago

It's another loss term that depends on all the estimators.

xuyxu commented 3 years ago

> It's another loss term that depends on all the estimators.

In this case, VotingClassifier is not applicable, since it trains each base estimator on its own. So, does your issue occur after adding the regularization term to the loss function of FusionClassifier?
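FusionClassifier averages the base estimators' logits (per `torchensemble/fusion.py`) and trains all of them jointly on the loss of that fused output, which is why a loss term coupling all estimators fits there, whereas VotingClassifier, as the reply above implies, fits each estimator separately. The sketch below is a hand-written joint training step under those assumptions, not torchensemble's internal code: `joint_step` and `coupling_penalty` are hypothetical names, and the penalty (mean pairwise squared distance between the members' softmax outputs) is only a placeholder for whatever regularizer is actually used.

```python
import torch
from torch import nn
import torch.nn.functional as F


def coupling_penalty(logits_list):
    """Hypothetical regularizer that depends on all estimators:
    mean pairwise squared distance between their softmax outputs.
    (Placeholder only; its form and sign depend on the real use case.)"""
    probs = [F.softmax(z, dim=1) for z in logits_list]
    total, pairs = 0.0, 0
    for i in range(len(probs)):
        for j in range(i + 1, len(probs)):
            total = total + (probs[i] - probs[j]).pow(2).sum(dim=1).mean()
            pairs += 1
    return total / max(pairs, 1)


def joint_step(estimators, optimizer, x, y, lam=0.1):
    """One joint update: cross-entropy on the averaged logits
    (fusion-style) plus the coupling term, in a single backward pass."""
    logits_list = [m(x) for m in estimators]
    fused = torch.stack(logits_list).mean(dim=0)  # (batch, classes)
    loss = F.cross_entropy(fused, y) + lam * coupling_penalty(logits_list)
    optimizer.zero_grad()
    loss.backward()  # gradients flow into every estimator at once
    optimizer.step()
    return loss.item()
```

The optimizer is assumed to hold the parameters of all estimators, e.g. `torch.optim.SGD([p for m in estimators for p in m.parameters()], lr=0.1)`.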

kiranchari commented 3 years ago

Actually, I trained a baseline ensemble without the additional regularization and still found that only one model learned anything; the others performed no better than random.

kiranchari commented 3 years ago

I was able to resolve this by averaging the softmax outputs of the estimators, instead of their logits, at both training and test time.
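The fix described above can be sketched as follows, assuming each estimator returns logits. When softmax outputs are averaged, the training loss becomes the negative log of the averaged probability of the true class; passing the averaged probabilities to `F.cross_entropy` would be wrong, because it applies its own softmax. Computing the log of the average as `logsumexp` over the members' `log_softmax` outputs minus log M keeps it numerically stable. `fuse_log_probs` and `prob_averaging_loss` are hypothetical helper names, not part of torchensemble.

```python
import math

import torch
import torch.nn.functional as F


def fuse_log_probs(logits_list):
    """Log of the averaged softmax probabilities, computed stably:
    log((1/M) * sum_m softmax(z_m)) = logsumexp_m(log_softmax(z_m)) - log(M)
    """
    log_probs = torch.stack([F.log_softmax(z, dim=1) for z in logits_list])  # (M, batch, classes)
    return torch.logsumexp(log_probs, dim=0) - math.log(len(logits_list))


def prob_averaging_loss(logits_list, y):
    # nll_loss expects log-probabilities, which fuse_log_probs provides
    return F.nll_loss(fuse_log_probs(logits_list), y)
```

At test time, `fuse_log_probs(logits_list).argmax(dim=1)` gives the same prediction as the argmax of the averaged probabilities, since the log is monotonic.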

xuyxu commented 3 years ago

Glad to see your problem solved. Is there any problem with the code of FusionClassifier?

kiranchari commented 3 years ago

In FusionClassifier, the output logits of the models are averaged, but I think averaging the models' softmax probabilities works better (and I believe it is also standard practice).

https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/torchensemble/fusion.py#L33
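One possible explanation for why logit averaging let a single estimator dominate, offered as a hypothesis rather than a confirmed diagnosis: with averaged logits, every member receives the same gradient, (p - y)/M with respect to its logits, so once one member's large logits make the fused prediction confident, that shared gradient shrinks toward zero for the other members too. With averaged probabilities, the fused probability of the true class cannot approach 1 while another member is still wrong, so a weak member keeps receiving a gradient. The toy comparison below uses a hypothetical setup (one confident member, one member outputting uniform predictions, 10 classes) and measures the gradient on the weak member's true-class logit under both fusion rules.

```python
import torch
import torch.nn.functional as F


def weak_member_grad(strong_margin, rule, num_classes=10):
    """Gradient of the fused loss w.r.t. the weak member's true-class logit.

    Hypothetical toy setup: member A (strong) puts `strong_margin` extra
    logit on the true class 0; member B (weak) outputs all-zero logits,
    i.e. a uniform prediction.
    """
    y = torch.tensor([0])
    z_strong = torch.zeros(1, num_classes)
    z_strong[0, 0] = strong_margin
    z_weak = torch.zeros(1, num_classes, requires_grad=True)
    if rule == "logits":
        # average logits, then softmax + cross-entropy (fusion-style)
        loss = F.cross_entropy((z_strong + z_weak) / 2, y)
    elif rule == "probs":
        # average softmax probabilities, then negative log-likelihood
        p = (F.softmax(z_strong, dim=1) + F.softmax(z_weak, dim=1)) / 2
        loss = F.nll_loss(torch.log(p), y)
    else:
        raise ValueError(f"unknown rule: {rule}")
    loss.backward()
    return z_weak.grad[0, 0].item()


for margin in (2.0, 10.0, 20.0):
    print(margin, weak_member_grad(margin, "logits"), weak_member_grad(margin, "probs"))
```

As the strong member's margin grows, the gradient reaching the weak member under logit averaging shrinks toward zero, while under probability averaging it stays bounded away from zero.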