francescamanni1989 opened 3 years ago
Thanks for reporting! @francescamanni1989
Could you provide the code snippet that reproduces the runtime error?
Hi,
The relevant code in gradientboosting.py is the variable-arguments (`*x`) part of `forward`, which runs when boosting is performed:

```python
def forward(self, *x):  # variable number of arguments
    output = [estimator(*x) for estimator in self.estimators]
    output = op.sum_with_multiplicative(output, self.shrinkage_rate)
    output = F.softmax(output, dim=1)
    return output
```
The error occurs when I try to script the model:

```python
model = model_ensemble
traced_model = torch.jit.script(model)
```
where `model_ensemble` could be:

```python
model_ensemble = GradientBoostingClassifier(
    estimator=MLP,
    n_estimators=10,
    cuda=False,
    shrinkage_rate=0.9,
)
```
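For anyone trying to reproduce this without torchensemble installed, the self-contained sketch below mimics the structure of the quoted `forward` with a hypothetical `VarArgsEnsemble`. Small `nn.Linear` layers stand in for the `MLP` estimators, and `op.sum_with_multiplicative` is approximated by scaling the summed outputs by the shrinkage rate (an assumption based on its name). Eager calls work, but `torch.jit.script` should reject the `*x` signature with the same `NotSupportedError` reported in this thread.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VarArgsEnsemble(nn.Module):
    """Hypothetical stand-in for an ensemble whose forward takes *x."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9):
        super().__init__()
        self.estimators = nn.ModuleList(
            [nn.Linear(4, 2) for _ in range(n_estimators)]
        )
        self.shrinkage_rate = shrinkage_rate

    def forward(self, *x):
        # Assumed equivalent of op.sum_with_multiplicative: rate * sum(outputs)
        outputs = [estimator(*x) for estimator in self.estimators]
        return F.softmax(self.shrinkage_rate * sum(outputs), dim=1)


model = VarArgsEnsemble()
print(model(torch.randn(5, 4)).shape)  # eager mode works: torch.Size([5, 2])

try:
    torch.jit.script(model)
    scripted_ok = True
except torch.jit.frontend.NotSupportedError as err:
    # The TorchScript frontend refuses *args / **kwargs in compiled functions.
    scripted_ok = False
    print("TorchScript rejected the module:", type(err).__name__)
```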
It looks like the package does not support TorchScript well for now. I will take a careful look when I get a moment, thanks!
Exactly! Thank you
Also, the `sum` function is not scriptable, but this could be bypassed using `@torch.jit.ignore`.
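A minimal sketch of the `@torch.jit.ignore` workaround, using a hypothetical `IgnoredSumEnsemble` with `nn.Linear` estimators: the helper that calls Python's builtin `sum` stays a plain Python method, and the scripted `forward` dispatches to it. The `*x` signature still has to be replaced with an explicit tensor argument, and according to the PyTorch docs a module with ignored methods cannot be exported with `torch.jit.save` (the docs point to `@torch.jit.unused` for that case).

```python
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class IgnoredSumEnsemble(nn.Module):
    """Hypothetical ensemble that keeps a Python-only helper out of TorchScript."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9):
        super().__init__()
        self.estimators = nn.ModuleList(
            [nn.Linear(4, 2) for _ in range(n_estimators)]
        )
        self.shrinkage_rate = shrinkage_rate

    @torch.jit.ignore
    def _sum_with_multiplicative(
        self, outputs: List[torch.Tensor], factor: float
    ) -> torch.Tensor:
        # Not compiled: runs in the Python interpreter, so builtin sum() is fine.
        return factor * sum(outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fixed-arity signature instead of *x.
        outputs: List[torch.Tensor] = []
        for estimator in self.estimators:
            outputs.append(estimator(x))
        output = self._sum_with_multiplicative(outputs, self.shrinkage_rate)
        return F.softmax(output, dim=1)


scripted = torch.jit.script(IgnoredSumEnsemble())
probs = scripted(torch.randn(5, 4))
print(probs.shape)
```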
My suggestion for the indexed variable is to use a for loop instead.
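Following the `for` loop suggestion, the sketch below (a hypothetical `LoopBoostingClassifier`, with `nn.Linear` layers standing in for the estimators) avoids both the list comprehension and the builtin `sum`: TorchScript accepts a plain loop over an `nn.ModuleList`, and `torch.stack(...).sum(dim=0)` performs the accumulation with scriptable tensor ops. That this reproduces the exact semantics of `op.sum_with_multiplicative` is an assumption.

```python
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoopBoostingClassifier(nn.Module):
    """Hypothetical fully scriptable variant of the quoted forward pass."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9):
        super().__init__()
        self.estimators = nn.ModuleList(
            [nn.Linear(4, 2) for _ in range(n_estimators)]
        )
        self.shrinkage_rate = shrinkage_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Plain for loop over the ModuleList instead of a comprehension.
        outputs: List[torch.Tensor] = []
        for estimator in self.estimators:
            outputs.append(estimator(x))
        # Stack to shape (n_estimators, batch, classes), sum over estimators,
        # then apply the shrinkage rate (assumed multiplicative).
        output = torch.stack(outputs, dim=0).sum(dim=0) * self.shrinkage_rate
        return F.softmax(output, dim=1)


model = LoopBoostingClassifier().eval()
scripted = torch.jit.script(model)

x = torch.randn(5, 4)
print(torch.allclose(scripted(x), model(x)))
```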
Hi everyone,
I am trying to script the ensemble; however, variable arguments (`*x`) cannot be used with TorchScript:
```
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File ".....\lib\site-packages\torchensemble\soft_gradient_boosting.py", line 390
        "classifier_forward",
    )
    def forward(self, *x):
                      ~~ <--- HERE
        output = [estimator(*x) for estimator in self.estimators]
        output = op.sum_with_multiplicative(output, self.shrinkage_rate)
```
Do you have any idea how to handle this?
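One possible way to handle this without editing the installed package is to wrap the trained estimators in a new module whose `forward` has a fixed signature, then script and save that wrapper. This sketch assumes the ensemble exposes its sub-models as `estimators` and its rate as `shrinkage_rate`, mirroring the snippet in this thread; `VarArgsEnsemble` and `ScriptableWrapper` are hypothetical names, so check the torchensemble source for the real attribute names before adapting it. Because the wrapper takes a single tensor, it only covers estimators with one input.

```python
import io
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class VarArgsEnsemble(nn.Module):
    """Hypothetical stand-in for a third-party ensemble with forward(self, *x)."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9):
        super().__init__()
        self.estimators = nn.ModuleList([nn.Linear(4, 2) for _ in range(n_estimators)])
        self.shrinkage_rate = shrinkage_rate

    def forward(self, *x):
        outputs = [estimator(*x) for estimator in self.estimators]
        return F.softmax(self.shrinkage_rate * sum(outputs), dim=1)


class ScriptableWrapper(nn.Module):
    """Re-exposes an ensemble's estimators behind a fixed-arity forward."""

    def __init__(self, ensemble: nn.Module):
        super().__init__()
        # Share (not copy) the trained sub-modules; store the rate as a float.
        self.estimators = ensemble.estimators
        self.shrinkage_rate = float(ensemble.shrinkage_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs: List[torch.Tensor] = []
        for estimator in self.estimators:
            outputs.append(estimator(x))
        output = torch.stack(outputs, dim=0).sum(dim=0) * self.shrinkage_rate
        return F.softmax(output, dim=1)


ensemble = VarArgsEnsemble().eval()
scripted = torch.jit.script(ScriptableWrapper(ensemble))

# Round-trip through an in-memory buffer to confirm the module is exportable.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)

x = torch.randn(5, 4)
print(torch.allclose(restored(x), ensemble(x), atol=1e-6))
```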