TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License
1.05k stars, 95 forks

`AdversarialTrainingRegressor` and `AdversarialTrainingClassifier` have incorrect `is_classification` values set #130

Closed. SarthakJariwala closed this issue 1 year ago.

SarthakJariwala commented 1 year ago

`AdversarialTrainingRegressor` sets `is_classification` to `True` in its call to `_parallel_fit_per_epoch`, although a regressor should use `False`:

https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/d887b0089fdab31cc30280a75481cddc54ade634/torchensemble/adversarial_training.py#L411

https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/d887b0089fdab31cc30280a75481cddc54ade634/torchensemble/adversarial_training.py#L510-L523
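To illustrate why the flag matters, here is a framework-free sketch. It assumes that `is_classification` gates whether per-batch accuracy is computed by taking the arg-max of each output row; the helper `batch_accuracy` is hypothetical and not part of torchensemble. A regressor emits one value per sample, so the arg-max of every row is index 0 and the resulting "accuracy" is meaningless.

```python
def batch_accuracy(outputs, targets, is_classification):
    """Hypothetical metric helper: accuracy in percent, or None for regression."""
    if not is_classification:
        return None  # regression: only the loss is meaningful
    # Classification: each row holds one score per class; the arg-max is the prediction.
    predicted = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    correct = sum(p == t for p, t in zip(predicted, targets))
    return 100.0 * correct / len(targets)


# A regressor produces one continuous value per sample (rows of length 1).
reg_outputs = [[2.3], [0.7], [-1.1], [4.0]]
reg_targets = [2.5, 0.5, -1.0, 3.8]

# With the flag wrongly set to True, every prediction collapses to index 0,
# so the "accuracy" only counts targets that happen to equal 0.
print(batch_accuracy(reg_outputs, reg_targets, is_classification=True))   # 0.0
print(batch_accuracy(reg_outputs, reg_targets, is_classification=False))  # None
```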

`AdversarialTrainingClassifier` sets `is_classification` to `False` in its call to `_parallel_fit_per_epoch`, although a classifier should use `True`:

https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/d887b0089fdab31cc30280a75481cddc54ade634/torchensemble/adversarial_training.py#L221

https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/d887b0089fdab31cc30280a75481cddc54ade634/torchensemble/adversarial_training.py#L324-L337