TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License
1.09k stars, 95 forks

Getting error while using GradientBoostingClassifier #146

Closed: singhabhinav closed this issue 1 year ago

singhabhinav commented 1 year ago

Hi,

I am getting the error below while using GradientBoostingClassifier. Could you please let me know how we can fix this?

[Screenshot 2023-03-01 at 3 33 44 PM]

xuyxu commented 1 year ago

Hi @singhabhinav, what is the number of classes in your data loader, and what is the network structure of your base estimator?
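One way to answer the first question without guessing is to scan the labels the data loader actually yields. The sketch below is plain Python and assumes each batch is an (inputs, labels) pair whose labels can be iterated as integers (iterating a PyTorch label tensor and calling int() on each element works). inspect_labels and labels_are_contiguous are hypothetical helpers, not part of torchensemble. Labels that are not exactly 0 to K-1 (for example, 60 classes numbered 1 to 60) are a common cause of one-hot and indexing errors.

```python
from typing import Any, Iterable, List, Sequence, Tuple


def inspect_labels(batches: Iterable[Tuple[Any, Sequence[int]]]) -> Tuple[int, List[int]]:
    """Return (number of distinct labels, sorted list of distinct labels).

    Hypothetical helper: `batches` is any iterable of (inputs, labels) pairs,
    such as a PyTorch DataLoader.
    """
    seen = set()
    for _, labels in batches:
        # int() also accepts 0-d tensors, so tensor labels work unchanged.
        seen.update(int(y) for y in labels)
    distinct = sorted(seen)
    return len(distinct), distinct


def labels_are_contiguous(distinct: Sequence[int]) -> bool:
    """True only if the labels are exactly 0, 1, ..., K-1."""
    return list(distinct) == list(range(len(distinct)))
```

The count returned by inspect_labels should match the out_features of the base estimator's last layer; if the labels are not contiguous, remapping them to 0..K-1 before training is the usual remedy.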

singhabhinav commented 1 year ago

Hi @xuyxu

I have 60 output classes, and the batch size in my data loader is 128. Please find the network structure attached.

[Screenshot 2023-03-04 at 2 50 43 PM]

xuyxu commented 1 year ago

Thanks @singhabhinav. One more thing to confirm: the data batch fed to CIF has the shape (batch_size, 10), right?

singhabhinav commented 1 year ago

@xuyxu Yes, I can confirm this.

The same code works fine with VotingClassifier but throws the above error with GradientBoostingClassifier. I printed the shape of the input tensor, and it is indeed (batch_size, 10).

[Screenshot 2023-03-05 at 1 18 12 PM]
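Why might code that trains under VotingClassifier fail under GradientBoostingClassifier? In voting, each base estimator is trained directly on the labels, whereas in multi-class gradient boosting each new estimator is regressed onto pseudo-residuals, one_hot(y) - softmax(F(x)), where F is the ensemble's accumulated logits. That step adds two requirements: every label must be a valid index below n_classes, and every estimator's output must have exactly n_classes columns. The sketch below is a generic, plain-Python illustration of that textbook recipe under those assumptions; it is not torchensemble's code, and the thread does not confirm that this was the cause of the reported error (which turned out to be version-related).

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot(label: int, n_classes: int) -> List[float]:
    """One-hot encode a class index, failing loudly if it is out of range."""
    if not 0 <= label < n_classes:
        raise ValueError(f"label {label} is outside [0, {n_classes})")
    vec = [0.0] * n_classes
    vec[label] = 1.0
    return vec


def pseudo_residuals(labels: Sequence[int],
                     logits: Sequence[Sequence[float]],
                     n_classes: int) -> List[List[float]]:
    """Negative gradient of cross-entropy w.r.t. the ensemble logits.

    Per sample: one_hot(y) - softmax(F(x)). The next base estimator is fit
    to these targets, so its output width must equal n_classes.
    """
    if len(labels) != len(logits):
        raise ValueError("labels and logits must have the same batch size")
    residuals = []
    for y, row in zip(labels, logits):
        if len(row) != n_classes:
            raise ValueError(
                f"expected {n_classes} logits per sample, got {len(row)}")
        probs = softmax(row)
        residuals.append([t - p for t, p in zip(one_hot(y, n_classes), probs)])
    return residuals
```

For example, with three classes, all-zero logits and label 0, the residual row is roughly [0.667, -0.333, -0.333]: the next estimator is pushed toward the true class and away from the others.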

singhabhinav commented 1 year ago

@xuyxu

I was able to run it after installing torchensemble from GitHub.

The version on PyPI (installed via pip) throws the above error, so I believe we can close this issue now.
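Because the fix came from switching installs, it helps to record which build is actually active when reporting a bug like this. A source install from GitHub is typically done with pip's VCS syntax, e.g. pip install git+https://github.com/TorchEnsemble-Community/Ensemble-Pytorch.git; the thread does not say which exact command was used. The minimal sketch below uses only the standard library and assumes the distribution name is torchensemble, the name used with pip install. installed_version is a hypothetical helper.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumed distribution names; adjust if your environment differs.
    for name in ("torchensemble", "torch"):
        print(f"{name}: {installed_version(name)}")
```

A GitHub install may report the same version string as the last PyPI release, so noting the commit hash alongside the version makes bug reports easier to reproduce.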