TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License
1.09k stars, 95 forks

BUG: overwriting estimators_ #54

Closed. Xiaohui9607 closed this issue 3 years ago.

Xiaohui9607 commented 3 years ago

```python
with Parallel(n_jobs=self.n_jobs) as parallel:

    ...
        if test_loader:
            ...

                if mse < best_mse:
                    best_mse = mse
                    self.estimators_ = nn.ModuleList()  # <--- stores the best-validation estimators
                    self.estimators_.extend(estimators)
                    if save_model:
                        io.save(self, save_dir, self.logger)

                msg = ("Epoch: {:03d} | Validation MSE:"
                       " {:.5f} | Historical Best: {:.5f}")
                self.logger.info(msg.format(epoch, mse, best_mse))

        # Update the scheduler
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)

            if self.use_scheduler_:
                scheduler_.step()

self.estimators_ = nn.ModuleList()  # <--- overwritten with the last-epoch estimators
self.estimators_.extend(estimators)
if save_model and not test_loader:
    io.save(self, save_dir, self.logger)
```
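The control flow in the excerpt can be reduced to a small, framework-free sketch. It is not the torchensemble `fit` method: `fit_sketch` is a hypothetical function, integers stand in for the `nn.Module` estimators, and the validation MSE values are supplied by hand. It shows that the first marked assignment records the best epoch, while the unconditional assignment after the loop replaces `estimators_` with whatever the last epoch produced.

```python
from typing import List, Tuple


def fit_sketch(val_mses: List[float]) -> Tuple[List[int], List[int], float]:
    """Mimic the two assignments to estimators_ in the excerpt above.

    Each epoch's "estimators" are represented by the epoch index, and
    val_mses[epoch] is a hypothetical validation MSE for that epoch.
    Returns (best_snapshot, final_estimators, best_mse).
    """
    if not val_mses:
        raise ValueError("need at least one epoch")

    best_mse = float("inf")
    estimators_: List[int] = []
    best_snapshot: List[int] = []  # stands in for what io.save writes to disk

    for epoch, mse in enumerate(val_mses):
        estimators = [epoch]  # stand-in for this epoch's trained estimators
        if mse < best_mse:
            best_mse = mse
            estimators_ = list(estimators)  # first marked assignment
            best_snapshot = list(estimators_)

    # Second marked assignment: runs after the loop regardless of best_mse.
    estimators_ = list(estimators)
    return best_snapshot, estimators_, best_mse


best, final, best_mse = fit_sketch([0.9, 0.4, 0.6])
print(best, final, best_mse)  # [1] [2] 0.4
```

With these values the best epoch is 1, but the in-memory result ends up holding epoch 2. The two only coincide when the last epoch also has the lowest validation MSE.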

xuyxu commented 3 years ago

Hi @Xiaohui9607, are you referring to the case where the ensemble, after calling `fit`, is actually not the model with the best validation performance?
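To make the distinction in this question concrete, here is one possible way (an assumption for illustration, not torchensemble's code or documented behavior) to leave the in-memory ensemble at its best-validation state: snapshot the estimators with `copy.deepcopy` whenever validation improves, then restore that snapshot after the loop. `ToyEstimator`, `ToyEnsemble`, and `fit_keep_best` are hypothetical stand-ins. In this toy the estimators are updated in place every epoch, so a plain reference would be changed by later epochs. The deep copy guards against that.

```python
import copy
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyEstimator:
    """Hypothetical stand-in for an nn.Module; `weight` mimics its parameters."""
    weight: float = 0.0


@dataclass
class ToyEnsemble:
    """Hypothetical stand-in for the ensemble object that owns estimators_."""
    estimators_: List[ToyEstimator] = field(default_factory=list)


def fit_keep_best(ensemble: ToyEnsemble, val_mses: List[float]) -> float:
    """'Train' for len(val_mses) epochs, then keep the best-validation estimators."""
    if not val_mses:
        raise ValueError("need at least one epoch")

    estimators = [ToyEstimator(), ToyEstimator()]
    best_mse = float("inf")
    best_estimators: List[ToyEstimator] = []

    for epoch, mse in enumerate(val_mses):
        # The toy "training" step mutates the same objects every epoch.
        for est in estimators:
            est.weight = float(epoch)
        if mse < best_mse:
            best_mse = mse
            # Copy, so later epochs cannot alter the snapshot.
            best_estimators = copy.deepcopy(estimators)

    # Restore the snapshot instead of the last-epoch estimators.
    ensemble.estimators_ = best_estimators
    return best_mse


ens = ToyEnsemble()
fit_keep_best(ens, [0.9, 0.4, 0.6])
print([e.weight for e in ens.estimators_])  # [1.0, 1.0]
```

The trade-off is holding a second full copy of the ensemble in memory during training, which is why saving the best checkpoint to disk, as the excerpt does with `io.save`, is a reasonable alternative.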

Xiaohui9607 commented 3 years ago

Yes. Never mind, I figured it out.

xuyxu commented 3 years ago

Thanks for reporting anyway.