MilesCranmer / PySR

High-Performance Symbolic Regression in Python and Julia
https://astroautomata.com/PySR
Apache License 2.0

[BUG]: Piecewise not in torch_mappings #639

Open tbuckworth opened 1 month ago

tbuckworth commented 1 month ago

What happened?

After fitting a PySR model with "greater" as a binary operator, exporting it to torch failed with the following error:

KeyError: 'Function Piecewise was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

I've seen that in #433 Piecewise was added to the mappings, so I'm surprised to see this error.

I attempted a fix myself without success, adding mappings such as:

{sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0])}

But then the same error arises for sympy.functions.elementary.piecewise.ExprCondPair, and after that for sympy.logic.boolalg.BooleanTrue.
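The cascade of KeyErrors is consistent with how SymPy represents a Piecewise: its arguments are ExprCondPair objects, and the catch-all branch uses the Boolean literal True (BooleanTrue). The minimal sketch below, which assumes "greater" is exported to SymPy as a Piecewise of this shape (an assumption for illustration, not taken from PySR's source), lists the node types a node-by-node torch converter would have to look up:

```python
import sympy

x, y = sympy.symbols("x y")

# Assumed SymPy form of greater(x, y): 1.0 where x > y, otherwise 0.0.
expr = sympy.Piecewise((1.0, x > y), (0.0, True))

# Every node type in the tree; a node-by-node converter needs a mapping for each.
node_types = {type(node).__name__ for node in sympy.preorder_traversal(expr)}
print(sorted(node_types))

# Each Piecewise argument is an (expr, cond) pair rather than a bare expression.
for pair in expr.args:
    print(type(pair).__name__, pair.expr, pair.cond)
```

This would explain why mapping sympy.Piecewise alone is not enough: the pairs and the True condition are separate nodes that the converter also visits.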

In the end, I added:

extra_torch_mappings = {
    sympy.Piecewise: lambda x, y: torch.where(x[1], x[0], y[0]),
    sympy.functions.elementary.piecewise.ExprCondPair: tuple,
    sympy.logic.boolalg.BooleanTrue: torch.BoolTensor,
    "greater": lambda x, y: torch.where(x > y, 1.0, 0.0),
}

But even this produced the following error:

KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

Hopefully, I am missing something obvious?

Version

0.18.4

Operating System

Linux

Package Manager

pip

Interface

Script (i.e., python my_script.py)

Relevant log output

No response

Extra Info

No response

tbuckworth commented 1 month ago

I just realised that #433 is a pull request, so I copied the code and used it to add the mappings manually. However, I'm still getting the error: KeyError: 'Function ITE was not found in Torch function mappings.Please add it to extra_torch_mappings in the format, e.g., {sympy.sqrt: torch.sqrt}.'

tbuckworth commented 1 month ago

I've added this mapping, which seems to circumvent the error, but I haven't fully tested it yet:

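import sympy
import torch

def if_then_else(*conds):
    # sympy's ITE(a, b, c) means "if a then b else c"; b and c are turned into bool tensors
    a, b, c = conds
    return torch.where(a, torch.where(b, True, False), torch.where(c, True, False))

extra_torch_mappings = {sympy.logic.boolalg.ITE: if_then_else}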
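A slightly more defensive variant of this workaround might cast all three arguments with .bool() first, so that float 0./1. tensors or plain Python bools are also accepted. This is a sketch assuming the exporter calls the mapped function with the three evaluated ITE arguments positionally, as the *conds signature above implies; the name ite is hypothetical.

```python
import torch

def ite(cond, if_true, if_false):
    """Hypothetical elementwise version of sympy's ITE(cond, if_true, if_false).

    Arguments may be tensors, Python bools, or float 0./1. tensors; each is
    converted to a bool tensor so torch.where receives a valid condition.
    torch.where broadcasts the three arguments to a common shape.
    """
    cond = torch.as_tensor(cond).bool()
    if_true = torch.as_tensor(if_true).bool()
    if_false = torch.as_tensor(if_false).bool()
    return torch.where(cond, if_true, if_false)
```

It would be registered the same way as above, e.g. extra_torch_mappings = {sympy.logic.boolalg.ITE: ite}.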

MilesCranmer commented 1 month ago

Nice! Yeah, that should be added to the pull request (#433). Feel free to suggest it on the PR via the review system, and you will be credited as a coauthor of the PR.

tbuckworth commented 1 month ago

Thanks! I'll add a review comment on the PR.

There was another error with Piecewise when cond is a float (1.0), which I fixed by replacing cond with cond.bool():

output += torch.where(
    cond.bool() & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond.bool()
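The first-match logic in that snippet can be sketched as a standalone function. This is a hypothetical helper, not PySR's implementation: it takes already-evaluated (expr, cond) pairs, casts every condition with .bool(), and broadcasts everything to a common shape. Instead of accumulating with +=, it overwrites with torch.where, which gives the same result because the masks never overlap.

```python
import torch

def piecewise_torch(pairs):
    """Hypothetical evaluator for sympy.Piecewise-style (expr, cond) pairs.

    As in sympy.Piecewise, the first pair whose condition holds wins.
    Conditions may be bool tensors, float 0./1. tensors, or Python bools.
    Positions where no condition holds stay 0 in this sketch.
    """
    dtype = torch.get_default_dtype()
    exprs = [torch.as_tensor(e, dtype=dtype) for e, _ in pairs]
    conds = [torch.as_tensor(c).bool() for _, c in pairs]
    # Common broadcast shape, so inputs with any number of batch dims work.
    shape = torch.broadcast_shapes(*(t.shape for t in exprs + conds))
    output = torch.zeros(shape, dtype=dtype)
    already_used = torch.zeros(shape, dtype=torch.bool)
    for expr, cond in zip(exprs, conds):
        # Fill only the positions not claimed by an earlier pair.
        output = torch.where(cond & ~already_used, expr, output)
        already_used = already_used | cond
    return output
```

For example, piecewise_torch([(1.0, x > y), (0.0, True)]) evaluates a "greater"-style Piecewise on tensors x and y.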

It now works as long as I use a single batch dimension, but inputs with multiple batch dimensions fail.

I believe this is due to export_torch.py, where _SingleSymPyModule.forward is:

def forward(self, X):
    if self._selection is not None:
        X = X[:, self._selection]
    symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
    return self._node(symbols)

If X[:, self._selection] and X[:, i] are replaced with X[..., self._selection] and X[..., i], I believe it will work. This is a separate issue, though, I suppose.
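The proposed indexing change can be checked outside the module. In the sketch below, split_symbols is a hypothetical stand-in for the indexing inside _SingleSymPyModule.forward, not PySR code. Indexing the last axis with ... keeps every leading batch dimension, and for 2D input it behaves exactly like X[:, ...].

```python
import torch

def split_symbols(X, n_symbols, selection=None):
    """Hypothetical stand-in for the feature indexing in forward().

    X may have shape (batch, n_features), (b1, b2, n_features), and so on;
    the feature axis is always assumed to be the last one.
    """
    if selection is not None:
        # Select feature columns on the last axis, keeping all batch dims.
        X = X[..., selection]
    # One tensor per input symbol, each shaped like X without its last axis.
    return [X[..., i] for i in range(n_symbols)]
```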

MilesCranmer commented 4 weeks ago

(Just leaving it open until that PR is closed, since there are still some TODO items)