from skompiler import skompile
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingRegressor
# Reproduction: compile a small gradient-boosting regressor with SKompiler
# and render its prediction expression as C++ via the 'sympy/cxx' target.
features, targets = load_iris(return_X_y=True)

# Three trees, each at most depth 3, with a fixed seed for repeatability.
booster = GradientBoostingRegressor(random_state=1, max_depth=3, n_estimators=3)
fitted = booster.fit(features, targets)

# The 'sympy/cxx' translation is the step that raises the TypeError reported below.
compiled = skompile(fitted.predict)
cxx_source = compiled.to('sympy/cxx')
print(cxx_source)
This gives the following error:
envs/komp/lib/python3.9/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: Attribute `n_features_` was deprecated in version 1.0 and will be removed in 1.2. Use `n_features_in_` instead.
  warnings.warn(msg, category=FutureWarning)
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    rust = expr.to('sympy/cxx')
  File "SKompiler/skompiler/ast.py", line 268, in to
    return translator(self, *dialect, *args, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 67, in translate
    return to_code(syexpr, dialect, assign_to=assign_to, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 236, in to_code
    return _code_printers[dialect](syexpr, assign_to=assign_to, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 204, in <lambda>
    'cxx': lambda expr, **kw: sp.cxxcode(expr, user_functions=_ufns, **kw),
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 865, in cxxcode
    return cxx_code_printers[standard.lower()](settings).doprint(expr, assign_to)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 150, in doprint
    lines = self._print(expr).splitlines()
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 375, in _print_Assignment
    return self._doprint_loops(rhs, lhs)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 184, in _doprint_loops
    dummies = get_contraction_structure(expr)
  File "envs/komp/lib/python3.9/site-packages/sympy/tensor/index_methods.py", line 443, in get_contraction_structure
    result[key] |= d[key]
TypeError: unsupported operand type(s) for |=: 'set' and 'Piecewise'
Hi, I tried the code above with SymPy 1.10, and it gives the error shown above.