TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

How to script the forward pass? #96

Open francescamanni1989 opened 3 years ago

francescamanni1989 commented 3 years ago

Hi everyone,

I am trying to script the ensemble, but variable arguments (`*args`, as in `forward(self, *x)`) cannot be used with TorchScript:

```
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File ".....\lib\site-packages\torchensemble\soft_gradient_boosting.py", line 390
        "classifier_forward",
    )
    def forward(self, *x):
                      ~~ <--- HERE
        output = [estimator(*x) for estimator in self.estimators_]
        output = op.sum_with_multiplicative(output, self.shrinkage_rate)
```

Do you have any idea how to handle it?
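The limitation can be reproduced without torchensemble. The sketch below uses two hypothetical modules: `VariadicForward`, which mirrors the `forward(self, *x)` signature from the traceback, and `FixedForward`, which takes a single typed tensor. Only the second is expected to compile with `torch.jit.script`.

```python
import torch
import torch.nn as nn


class VariadicForward(nn.Module):
    """Hypothetical module mirroring the `def forward(self, *x)` signature."""

    def forward(self, *x):
        return x[0] * 2


class FixedForward(nn.Module):
    """Same computation, but with one explicitly typed tensor argument."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


def try_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script accepts the module, False otherwise."""
    try:
        torch.jit.script(module)
        return True
    except Exception:  # the variadic signature raises NotSupportedError
        return False


variadic_ok = try_script(VariadicForward())
fixed_ok = try_script(FixedForward())
```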

xuyxu commented 3 years ago

Thanks for reporting! @francescamanni1989

Could you provide a code snippet that reproduces the runtime error?

francescamanni1989 commented 3 years ago

Hi,

The relevant code in gradient_boosting.py is the variadic `*x` argument of `forward`, which is used when the boosted ensemble computes its output:

```python
def forward(self, *x):
    output = [estimator(*x) for estimator in self.estimators_]
    output = op.sum_with_multiplicative(output, self.shrinkage_rate)
    output = F.softmax(output, dim=1)
    return output
```
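One possible workaround is sketched below under several assumptions. The approach is to give `forward` a single tensor argument, loop over the estimators explicitly, and replace the Python-level summation with tensor ops. It assumes that `op.sum_with_multiplicative(outputs, factor)` computes `factor * sum(outputs)`, as its name suggests. `ScriptableBoostingHead` and the `nn.Linear` estimators are hypothetical stand-ins, not part of torchensemble.

```python
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScriptableBoostingHead(nn.Module):
    """Hypothetical scriptable variant of the boosting forward pass above."""

    def __init__(self, estimators: List[nn.Module], shrinkage_rate: float):
        super().__init__()
        self.estimators_ = nn.ModuleList(estimators)
        self.shrinkage_rate = shrinkage_rate  # stored as a float attribute

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Explicit loop over the ModuleList; TorchScript unrolls it when compiling
        outputs: List[torch.Tensor] = []
        for estimator in self.estimators_:
            outputs.append(estimator(x))
        # torch.stack(...).sum(0) stands in for Python's built-in sum()
        output = torch.stack(outputs, dim=0).sum(dim=0) * self.shrinkage_rate
        return F.softmax(output, dim=1)


torch.manual_seed(0)
model = ScriptableBoostingHead([nn.Linear(4, 3) for _ in range(3)], shrinkage_rate=0.9)
scripted = torch.jit.script(model)
x = torch.randn(2, 4)
eager_out = model(x)
scripted_out = scripted(x)
```

The trade-off is that the fixed signature only supports estimators that take one input tensor, which is exactly the flexibility `*x` provides.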

The error occurs when I try to script the model:

```python
model = model_ensemble
traced_model = torch.jit.script(model)
```

where model_ensemble could be:

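```python
from torchensemble import GradientBoostingClassifier

# MLP is the user's own nn.Module subclass (not shown)
model_ensemble = GradientBoostingClassifier(
    estimator=MLP,
    n_estimators=10,
    cuda=False,
    shrinkage_rate=0.9,
)
```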
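The snippet above calls `torch.jit.script` but names the result `traced_model`. If tracing is acceptable, `torch.jit.trace` may sidestep the problem. It runs the Python `forward` on example inputs and records the tensor operations, so the `*x` signature and the built-in `sum` are never compiled. The sketch uses a hypothetical `VariadicEnsemble` stand-in and has not been checked against the real `GradientBoostingClassifier`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VariadicEnsemble(nn.Module):
    """Hypothetical stand-in that keeps the `*x` signature of the ensemble forward."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9):
        super().__init__()
        self.estimators_ = nn.ModuleList([nn.Linear(4, 3) for _ in range(n_estimators)])
        self.shrinkage_rate = shrinkage_rate

    def forward(self, *x):
        output = [estimator(*x) for estimator in self.estimators_]
        output = self.shrinkage_rate * sum(output)  # plain Python sum over tensors
        return F.softmax(output, dim=1)


torch.manual_seed(0)
model = VariadicEnsemble().eval()
example = torch.randn(2, 4)
# Tracing executes forward() in Python, so *x and built-in sum() are allowed
traced_model = torch.jit.trace(model, example)
x = torch.randn(5, 4)  # a different batch size than the example input
```

Tracing only records the operations executed for the example input. Any data-dependent control flow would therefore be frozen, and the number of estimators is fixed at trace time.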

xuyxu commented 3 years ago

It looks like the package does not support TorchScript well for now. I will take a careful look when I get a moment, thanks!

francescamanni1989 commented 3 years ago

Exactly! Thank you

francescamanni1989 commented 3 years ago

Also, the `sum` function is not scriptable, but this could be bypassed using `@torch.jit.ignore`.
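The sketch below shows the `@torch.jit.ignore` route. It assumes the summation is isolated in its own method; the hypothetical `_python_sum` stands in for `op.sum_with_multiplicative`. The compiler leaves the ignored method as a call into the Python interpreter, so the scripted module works inside a Python process. However, the PyTorch documentation states that models with ignored functions cannot be exported (for example with `torch.jit.save`). This route therefore only helps if the model never has to leave Python.

```python
from typing import List

import torch
import torch.nn as nn


class IgnoredSumEnsemble(nn.Module):
    """Hypothetical ensemble whose summation step is left to Python."""

    def __init__(self, estimators: List[nn.Module], shrinkage_rate: float):
        super().__init__()
        self.estimators_ = nn.ModuleList(estimators)
        self.shrinkage_rate = shrinkage_rate

    @torch.jit.ignore
    def _python_sum(self, outputs: List[torch.Tensor], factor: float) -> torch.Tensor:
        # Not compiled: run by the Python interpreter, so built-in sum() is fine here
        return factor * sum(outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs: List[torch.Tensor] = []
        for estimator in self.estimators_:
            outputs.append(estimator(x))
        # The compiler emits a call back into Python instead of compiling _python_sum
        return self._python_sum(outputs, self.shrinkage_rate)


torch.manual_seed(0)
model = IgnoredSumEnsemble([nn.Linear(4, 3) for _ in range(3)], shrinkage_rate=0.9)
scripted = torch.jit.script(model)
x = torch.randn(2, 4)
eager_out = model(x)
scripted_out = scripted(x)
```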

francescamanni1989 commented 3 years ago

My suggestion for the indexed variable is to use a for loop instead.