Translation to another language is broken in SymPy 1.10 #12

Open darleybarreto opened 2 years ago

darleybarreto commented 2 years ago

Hi, I tried the following code with SymPy 1.10 (translating a fitted model to C++ via the `sympy/cxx` target):

from skompiler import skompile
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingRegressor

X, y = load_iris(return_X_y=True)

regr = GradientBoostingRegressor(random_state=1, max_depth=3, n_estimators=3)
res = regr.fit(X, y)
expr = skompile(res.predict)
rust = expr.to('sympy/cxx')

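It fails with the following error. The `FutureWarning` at the top is only a scikit-learn deprecation notice; the actual failure is the `TypeError` at the end of the traceback: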

envs/komp/lib/python3.9/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: Attribute `n_features_` was deprecated in version 1.0 and will be removed in 1.2. Use `n_features_in_` instead.
  warnings.warn(msg, category=FutureWarning)
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    rust = expr.to('sympy/cxx')
  File "SKompiler/skompiler/ast.py", line 268, in to
    return translator(self, *dialect, *args, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 67, in translate
    return to_code(syexpr, dialect, assign_to=assign_to, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 236, in to_code
    return _code_printers[dialect](syexpr, assign_to=assign_to, **kw)
  File "SKompiler/skompiler/fromskast/sympy.py", line 204, in <lambda>
    'cxx': lambda expr, **kw: sp.cxxcode(expr, user_functions=_ufns, **kw),
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 865, in cxxcode
    return cxx_code_printers[standard.lower()](settings).doprint(expr, assign_to)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 150, in doprint
    lines = self._print(expr).splitlines()
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 375, in _print_Assignment
    return self._doprint_loops(rhs, lhs)
  File "envs/komp/lib/python3.9/site-packages/sympy/printing/codeprinter.py", line 184, in _doprint_loops
    dummies = get_contraction_structure(expr)
  File "envs/komp/lib/python3.9/site-packages/sympy/tensor/index_methods.py", line 443, in get_contraction_structure
    result[key] |= d[key]
TypeError: unsupported operand type(s) for |=: 'set' and 'Piecewise'