MilesCranmer / PySR

High-Performance Symbolic Regression in Python and Julia
https://astroautomata.com/PySR
Apache License 2.0

[BUG]: torch export fails for expressions with constant inputs e.g. exp(2) #656

Open tbuckworth opened 3 months ago

tbuckworth commented 3 months ago

What happened?

sympy2torch produces a module that fails when called if the expression contains a function applied to a constant.

For example:

from sympy import symbols, exp
from pysr import sympy2torch
import torch

x, y = symbols("x y")
expression = exp(2)
module = sympy2torch(expression, [x, y])

X = torch.rand(100, 2).float() * 10
torch_out = module(X)

This produces the following error:

TypeError: exp(): argument 'input' (position 1) must be Tensor, not float

I've tried other expressions, such as log(4), and they produce the same problem.

The current mapping in export_torch.py is sympy.exp: torch.exp.

I believe that

def exp(x):
    # wrap plain Python numbers so torch.exp always receives a tensor
    return torch.exp(torch.as_tensor(x, dtype=torch.float32))

and then using the mapping sympy.exp: exp might work, but I have been unable to test it: adding it to extra_sympy_mappings doesn't work, I think because the extra mappings are chained onto the end of the existing ones and don't override the original entry.

Alternatively, simplifying every constant subexpression to a numeric value where possible might solve the problem in general, e.g. exp(2) becomes 7.38905609893.
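
For illustration, here is a rough sketch of that idea using plain SymPy; this is not PySR code, just one hedged way of folding constant subexpressions before export:

import sympy

x1 = sympy.Symbol("x1")
expr = sympy.exp(2) + 1.5 * x1

# Replace every function application with no free symbols by its numeric value,
# so the exported module only ever sees plain constants.
constant_subexprs = [sub for sub in expr.atoms(sympy.Function) if not sub.free_symbols]
folded = expr.xreplace({sub: sympy.Float(sub.evalf()) for sub in constant_subexprs})
print(folded)  # 1.5*x1 + 7.38905609893065 (up to printing order)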

Version

0.18.4

Operating System

Linux

Package Manager

pip

Interface

Script (i.e., python my_script.py)

Relevant log output

No response

Extra Info

No response

MilesCranmer commented 3 months ago

Thanks for the bug report! One thing that surprises me is that I thought this case was already handled. See this code: https://github.com/MilesCranmer/PySR/blob/06ca0e376e63d563aa063028a5f9bc7fa7d849c5/pysr/export_torch.py#L94-L97

The code torch.tensor(float(expr)) should already map any constant into a torch tensor.

Maybe the issue is that you are explicitly passing a Python integer, rather than a SymPy integer?

For example, we can see these are actually different classes:

In [4]: isinstance(1, sympy.Integer)
Out[4]: False
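
Whereas the SymPy-native integer does pass the check (standard SymPy behaviour, shown here for contrast):

In [5]: isinstance(sympy.Integer(1), sympy.Integer)
Out[5]: True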

Did you see this error from a PySR export, or are you trying to use sympy2torch manually and putting in the integers explicitly?

tbuckworth commented 3 months ago

This came about when using PySRRegressor.fit(), which produced an expression containing square(exp(sign(0.44796443))); this seems to simplify to exp(2).

I recreated the issue using expression = exp(sign(0.44796443))*exp(sign(0.44796443)) originally, but wrote exp(2) here as a minimal example.

MilesCranmer commented 3 months ago

Do you know the original error message? The MWE might actually be a different issue. The exp(2) should never actually occur; it should (I think) be exp(sympy.Integer(2)).

Perhaps because it was sign(..), it is some kind of floating-point number PySR doesn't account for.

tbuckworth commented 3 months ago

Oh, apologies if this is my fault, but I was using extra_torch_mappings that included:

sympy.core.numbers.Exp1: exp1

where

def exp1():
    return torch.exp(torch.FloatTensor([1]))

I believe I added this because of an error that arose when exporting an expression containing exp(sign(0.1...)) to torch, but I don't remember exactly.

In terms of the original error in this issue, the PySRRegressor.fit function learned this expression:

(square(x2 / 0.10893087) * exp(x3)) - square(exp(sign(0.44796443)))

I then called model.pytorch(), which resulted in this error:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
    (1): _Node(
      (_args): ModuleList(
        (0): _Node()
      )
    )
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 300, in <module>
    optimize_hyperparams(bounds, fixed, project, id_tag, run_graph_hyperparameters)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 141, in optimize_hyperparams
    run_next(hparams)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 116, in run_graph_hyperparameters
    run_graph_neurosymbolic_search(args)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 503, in run_graph_neurosymbolic_search
    fine_tuned_policy = fine_tune(ns_agent.policy, logdir, symbdir, hp_override)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 397, in fine_tune
    agent.train(args.num_timesteps)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 213, in train
    act, value = self.predict(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 107, in predict
    dist, value = self.policy(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 124, in forward
    d, r = self.all_dones_rewards(s)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 156, in all_dones_rewards
    dones, rew = self.dr(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 102, in dr
    d = self.done_model(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1065, in forward
    return self.fwd(X)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1061, in fwd
    return self.model._node(symbols)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 156, in forward
    return self._torch_func(*args)
TypeError: exp(): argument 'input' (position 1) must be Tensor, not float

MilesCranmer commented 3 months ago

It might be because your function exp1 is returning torch.exp(torch.FloatTensor([1])) rather than torch.exp(torch.tensor(1))? (note the difference in shape)

- torch.exp(torch.FloatTensor([1]))
+ torch.exp(torch.tensor(1))
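
Concretely (assuming stock PyTorch), the two differ in shape like this:

In [1]: torch.FloatTensor([1]).shape
Out[1]: torch.Size([1])

In [2]: torch.tensor(1).shape
Out[2]: torch.Size([])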

But normally I would just do

extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1.0))}

which is similar to the definitions for sympy.core.numbers.Half and sympy.core.numbers.One.
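
For reference, a hedged sketch of how that mapping could be passed at the model level; the keyword arguments here are the usual PySRRegressor ones, but adapt them to your own setup:

import math
import sympy
from pysr import PySRRegressor

# Map SymPy's E (Exp1) to a plain Python float so the exported torch module
# never receives a stray 1-element tensor for the constant.
model = PySRRegressor(
    unary_operators=["exp", "sign", "square"],
    extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1.0))},
)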

MilesCranmer commented 3 months ago

Ok, weird, I can actually reproduce it with the following, which seems to be the same as what you saw originally:

import math
import pysr
import sympy
import torch

ex = pysr.export_sympy.pysr2sympy(
    "square(exp(sign(0.44796443))) + 1.5 * x1",
    feature_names_in=["x1"],
    extra_sympy_mappings={"square": lambda x: x**2},
)

def exp1():
    return torch.exp(torch.FloatTensor([1]))

m = pysr.export_torch.sympy2torch(
    ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: exp1}
)
m(torch.randn(10, 1))  # Errors

m2 = pysr.export_torch.sympy2torch(
    ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1))}
)
m2(torch.randn(10, 1))  # Also errors

MilesCranmer commented 3 months ago

Ah, I got it! It's because we don't have a branch for sympy.core.numbers.NumberSymbol. Argh...

https://github.com/MilesCranmer/PySR/blob/06ca0e376e63d563aa063028a5f9bc7fa7d849c5/pysr/export_torch.py#L94-L122

This will also need to be added to the sympy2jax code, I guess.
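
For anyone following along, here is a rough, untested sketch of the missing handling; the helper below is hypothetical and only mirrors the idea of the class dispatch in export_torch.py, not the actual fix:

import sympy
import torch

def lower_constant(expr):
    # Numbers and number symbols get baked in as 0-dim float tensors,
    # which is what torch.exp / torch.sin expect.
    if isinstance(expr, (sympy.Float, sympy.Rational)):
        return torch.tensor(float(expr))
    elif isinstance(expr, sympy.core.numbers.NumberSymbol):  # the missing branch: sympy.E, sympy.pi, ...
        return torch.tensor(float(expr))
    raise TypeError(f"not a constant: {expr!r}")

print(lower_constant(sympy.E))           # tensor(2.7183)
print(lower_constant(sympy.pi))          # tensor(3.1416)
print(lower_constant(sympy.Integer(2)))  # tensor(2.)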

tbuckworth commented 5 days ago

I see this was fixed in version 0.19.0.

However, the issue still arises now and then for me, with the function sin.

I can recreate the issue with the following code:

from sympy import symbols, sin, sign
from pysr import sympy2torch
import torch

x, y = symbols("x y")
expression = sin(sign(-0.041662704))
module = sympy2torch(expression, [x, y])

X = torch.rand(100, 2).float() * 10
torch_out = module(X)

which fails with:

TypeError: sin(): argument 'input' (position 1) must be Tensor, not float

MilesCranmer commented 5 days ago

Thanks for making a MWE. I’ll take a look. It seems like running sympy2torch directly on a float is what causes the issue?

tbuckworth commented 5 days ago

If I run it directly on a float I get a different error:

AttributeError: 'float' object has no attribute 'func'

tbuckworth commented 5 days ago

If you remove sign from the original expression, there's no error, but if you replace the expression with sin(-1), it throws the same error.

I would guess it's still to do with sympy.core.numbers.NumberSymbol?
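
For what it's worth, evaluating the pieces in a SymPy REPL (standard SymPy behaviour, worth re-checking locally) shows how an integer constant ends up inside sin:

In [1]: import sympy

In [2]: sympy.sign(-0.041662704)
Out[2]: -1

In [3]: sympy.sin(sympy.sign(-0.041662704))
Out[3]: -sin(1)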

tbuckworth commented 5 days ago

I've proposed a change in #726 to this code:

https://github.com/MilesCranmer/PySR/blob/339cc0a96be6cb0c41daa6c0ffa3a76cb1ecc9e4/pysr/export_torch.py#L117-L121

Is that feasible, or do you think it would break other behaviour?