konstantint / SKompiler

A tool for compiling trained SKLearn models into other representations (such as SQL, Sympy or Excel formulas)
MIT License

Translation to another language is broken in Sympy 1.10 #12

Open · darleybarreto opened this issue 2 years ago

darleybarreto commented 2 years ago

Hi, I tried the following code with SymPy 1.10:

```python
from skompiler import skompile
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingRegressor

X, y = load_iris(return_X_y=True)

regr = GradientBoostingRegressor(random_state=1, max_depth=3, n_estimators=3)
res = regr.fit(X, y)
expr = skompile(res.predict)
rust = expr.to('sympy/cxx')
print(rust)
```

It fails with the following error:

```
envs/komp/lib/python3.9/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: Attribute `n_features_` was deprecated in version 1.0 and will be removed in 1.2. Use `n_features_in_` instead.
  warnings.warn(msg, category=FutureWarning)
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    rust = expr.to('sympy/cxx')
  File "SKompiler/skompiler/ast.py", line 268, in to
    return translator(self, *dialect, *args, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 67, in translate
    return to_code(syexpr, dialect, assign_to=assign_to, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 236, in to_code
    return _code_printers[dialect](syexpr, assign_to=assign_to, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 204, in <lambda>
    'cxx': lambda expr, **kw: sp.cxxcode(expr, user_functions=_ufns, **kw),
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 865, in cxxcode
    return cxx_code_printers[standard.lower()](settings).doprint(expr, assign_to)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 150, in doprint
    lines = self._print(expr).splitlines()
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 375, in _print_Assignment
    return self._doprint_loops(rhs, lhs)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 184, in _doprint_loops
    dummies = get_contraction_structure(expr)
  File "envs/komp/lib/python3.9/site-packages/sympy/tensor/index_methods.py", line 443, in get_contraction_structure
    result[key] |= d[key]
TypeError: unsupported operand type(s) for |=: 'set' and 'Piecewise'
```
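The last frame is inside `get_contraction_structure`, which merges each term's `{key: set}` result with `result[key] |= d[key]`. Going by the message `'set' and 'Piecewise'`, one term contributed a set while a Piecewise term contributed the bare `Piecewise` object. Below is a simplified, stdlib-only model of that merge, just to show why that combination fails. `merge_term_structures` and `PiecewiseStandIn` are hypothetical names for illustration, not SymPy code, and the assumption that the set-valued term is a constant (e.g. the boosting model's initial prediction) is a guess about what SKompiler emits.

```python
class PiecewiseStandIn:
    """Hypothetical placeholder for a SymPy Piecewise term (not a set)."""

    def __repr__(self):
        return "Piecewise(...)"


def merge_term_structures(term_structures):
    """Union per-term {key: set_of_terms} dicts (simplified model of the merge)."""
    result = {}
    for d in term_structures:
        for key, value in d.items():
            # Start each key from an empty set, then union in this term's entry;
            # `|=` raises TypeError when `value` is not a set.
            result.setdefault(key, set())
            result[key] |= value
    return result


# Assumed: a constant term reports its structure as a set of terms...
constant_term = {None: {0.5}}
# ...while the Piecewise term (going by the error message) yields a bare object.
piecewise = PiecewiseStandIn()
bad_piecewise_term = {None: piecewise}
# Wrapping the same object in a set gives the shape the merge expects.
good_piecewise_term = {None: {piecewise}}

error_message = None
try:
    merge_term_structures([constant_term, bad_piecewise_term])
except TypeError as exc:
    error_message = str(exc)

merged = merge_term_structures([constant_term, good_piecewise_term])
print(error_message)  # unsupported operand type(s) for |=: 'set' and 'PiecewiseStandIn'
print(len(merged[None]))  # 2
```

If this reading is right, the merge works once the Piecewise entry is wrapped in a set, as `good_piecewise_term` does.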
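The failure may be reproducible with SymPy alone, which would rule out scikit-learn and SKompiler. The sketch below assumes SKompiler's output looks roughly like a constant plus a sum of `Piecewise` terms over an `IndexedBase` input `x`; the expression, `shape=(4,)`, thresholds and leaf values are all made up. With the default `contract=True`, the C/C++ printers send an assignment whose right-hand side contains indexed objects through `_doprint_loops`, which is where the traceback enters `get_contraction_structure`. Passing `contract=False` should skip that analysis and print each `Piecewise` as an inline ternary.

```python
import sympy as sp

# Assumed stand-in for SKompiler's output: a constant plus a sum of
# tree-like Piecewise terms over an indexed input vector x[0], x[1], ...
# (shape, thresholds and leaf values are made up for illustration).
x = sp.IndexedBase("x", shape=(4,))
y = sp.Symbol("y")
expr = (
    0.5
    + sp.Piecewise((1.0, x[0] < 2.5), (-1.0, True))
    + sp.Piecewise((0.25, x[2] < 4.75), (-0.25, True))
)


def try_cxx(**settings):
    """Return the generated C++ code, or a short error description."""
    try:
        return sp.cxxcode(expr, assign_to=y, **settings)
    except Exception as exc:  # report any failure instead of crashing
        return f"{type(exc).__name__}: {exc}"


# Default settings (contract=True): the assignment goes through
# _doprint_loops / get_contraction_structure, where the reported
# TypeError is raised on SymPy 1.10.
default_result = try_cxx()

# contract=False skips the loop/contraction analysis, so each Piecewise
# is printed as an inline ternary expression instead.
workaround_result = try_cxx(contract=False)

print(default_result)
print(workaround_result)
```

Since the traceback shows SKompiler's `to_code` forwarding extra keyword arguments to `sp.cxxcode`, `expr.to('sympy/cxx', contract=False)` might work as a stopgap as well; that is untested here.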