tbuckworth opened this issue 3 months ago
Thanks for the bug report! I'm a bit surprised, because I thought this was already handled. See this code: https://github.com/MilesCranmer/PySR/blob/06ca0e376e63d563aa063028a5f9bc7fa7d849c5/pysr/export_torch.py#L94-L97
The code torch.tensor(float(expr)) should already map any constant into a torch tensor.
Maybe the issue is that you are explicitly passing a Python integer, rather than a SymPy integer?
For example, we can see these are actually different classes:
In [4]: isinstance(1, sympy.Integer)
Out[4]: False
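To make the distinction concrete, here is a small SymPy-only sketch (it assumes SymPy's default behaviour, where exp(2) stays unevaluated): a bare Python int fails the isinstance check, but SymPy converts it to sympy.Integer as soon as it is passed into a function such as exp.

```python
import sympy

# A plain Python int is not a SymPy Integer.
print(isinstance(1, sympy.Integer))

# SymPy sympifies function arguments, so the 2 inside exp(2)
# is stored as sympy.Integer(2), not as a Python int.
expr = sympy.exp(2)
arg = expr.args[0]
print(type(arg), isinstance(arg, sympy.Integer))
```

So an expression built through SymPy (or through PySR's export) should carry SymPy numbers, not raw Python ints.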
Did you see this error from a PySR export, or are you trying to use sympy2torch manually and putting in the integers explicitly?
This came about using PySRRegressor.fit(), which produced an expression containing square(exp(sign(0.44796443))), which seems to simplify to exp(2).
I originally recreated the issue using expression = exp(sign(0.44796443))*exp(sign(0.44796443)), but wrote exp(2) here as a minimal example.
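For reference, the simplification can be traced step by step in SymPy alone. This is a sketch that takes square to mean x**2, as in the extra_sympy_mappings used later in this thread. sign of a positive Float evaluates to the Integer 1, exp(1) then collapses to the constant E, and squaring gives exp(2). Note that E is a NumberSymbol, not an Integer or a Float.

```python
import math
import sympy

c = sympy.sign(0.44796443)  # sign of a positive Float evaluates to Integer 1
inner = sympy.exp(c)        # exp(1) auto-evaluates to the constant E
expr = inner**2             # square(...) gives E**2, which SymPy writes as exp(2)

print(c, inner, expr)
# E is a NumberSymbol, a separate class from Integer and Float.
print(isinstance(inner, sympy.core.numbers.NumberSymbol))
print(float(expr), math.exp(2))
```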
Do you know the original error message? It could be that the MWE is actually a different issue. exp(2) should never actually occur; it should (I think) be exp(sympy.Integer(2)). At least, it should be.
Perhaps, because it came from sign(...), it is some kind of floating-point number that PySR doesn't account for?
Oh, apologies if this is my fault, but I was using extra_torch_mappings that included:
sympy.core.numbers.Exp1: exp1
where
def exp1():
    return torch.exp(torch.FloatTensor([1]))
I believe I added this because of an error that arose when trying to export an expression containing exp(sign(0.1...)) to torch, but I don't remember exactly.
In terms of the original error in this issue, the PySRRegressor.fit function learned this expression:
(square(x2 / 0.10893087) * exp(x3)) - square(exp(sign(0.44796443)))
I then called model.pytorch(), which resulted in this error:
Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
    (1): _Node(
      (_args): ModuleList(
        (0): _Node()
      )
    )
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 300, in <module>
    optimize_hyperparams(bounds, fixed, project, id_tag, run_graph_hyperparameters)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 141, in optimize_hyperparams
    run_next(hparams)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 116, in run_graph_hyperparameters
    run_graph_neurosymbolic_search(args)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 503, in run_graph_neurosymbolic_search
    fine_tuned_policy = fine_tune(ns_agent.policy, logdir, symbdir, hp_override)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 397, in fine_tune
    agent.train(args.num_timesteps)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 213, in train
    act, value = self.predict(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 107, in predict
    dist, value = self.policy(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 124, in forward
    d, r = self.all_dones_rewards(s)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 156, in all_dones_rewards
    dones, rew = self.dr(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 102, in dr
    d = self.done_model(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1065, in forward
    return self.fwd(X)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1061, in fwd
    return self.model._node(symbols)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 156, in forward
    return self._torch_func(*args)
TypeError: exp(): argument 'input' (position 1) must be Tensor, not float
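It might be because your function exp1 is returning torch.exp(torch.FloatTensor([1])) rather than torch.exp(torch.tensor(1))? Note the difference in shape: torch.FloatTensor([1]) is a one-element tensor of shape (1,), while torch.tensor(1) is a zero-dimensional scalar tensor.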
- torch.exp(torch.FloatTensor([1]))
+ torch.exp(torch.tensor(1))
But normally I would just do
extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1.0))}
which is similar to the definitions for sympy.core.numbers.Half and sympy.core.numbers.One.
OK, weird: I can actually reproduce it with the following, which sounds like the same thing you saw originally:
import math
import pysr
import sympy
import torch
ex = pysr.export_sympy.pysr2sympy(
"square(exp(sign(0.44796443))) + 1.5 * x1",
feature_names_in=["x1"],
extra_sympy_mappings={"square": lambda x: x**2},
)
def exp1():
    return torch.exp(torch.FloatTensor([1]))
m = pysr.export_torch.sympy2torch(
ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: exp1}
)
m(torch.randn(10, 1)) # Errors
m2 = pysr.export_torch.sympy2torch(
ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1))}
)
m2(torch.randn(10, 1)) # Also errors
Ah, I got it! It's because we don't have a branch for sympy.core.numbers.NumberSymbol. Argh...
It will also need to be added to the sympy2jax code, I guess.
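A minimal sketch of what such a branch could check, assuming the goal is simply to turn any SymPy numeric leaf into a plain float before it reaches torch. constant_to_float is a hypothetical helper for illustration, not a PySR function. Because float() works on every NumberSymbol (E, pi, GoldenRatio, ...), a single isinstance check covers all of them alongside the ordinary sympy.Number subclasses.

```python
import sympy
from sympy.core.numbers import NumberSymbol


def constant_to_float(expr):
    """Hypothetical helper: return a float for numeric SymPy leaves, else None."""
    # Integer, Rational and Float are subclasses of sympy.Number;
    # E, pi, EulerGamma, GoldenRatio, ... are NumberSymbols, which are not.
    if isinstance(expr, (sympy.Number, NumberSymbol)):
        return float(expr)
    # Symbols and compound expressions need a different branch.
    return None


for leaf in (sympy.Integer(2), sympy.Rational(1, 2), sympy.E, sympy.pi, sympy.Symbol("x")):
    print(leaf, constant_to_float(leaf))
```

In the torch exporter, the returned value would presumably still need wrapping with torch.tensor(...), as the linked code does with torch.tensor(float(expr)), so that functions such as torch.exp receive a Tensor rather than a Python float.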
I see this was fixed in version 0.19.0. However, the issue still arises now and then for me, with the function sin.
I can recreate the issue with the following code:
from sympy import symbols, sin, sign
from pysr import sympy2torch
import torch
x, y = symbols("x y")
expression = sin(sign(-0.041662704))
module = sympy2torch(expression, [x, y])
X = torch.rand(100, 2).float() * 10
torch_out = module(X)
TypeError: sin(): argument 'input' (position 1) must be Tensor, not float
Thanks for making an MWE. I'll take a look. It seems like running sympy2torch directly on a float causes the issue?
If I run it directly on a float I get a different error:
AttributeError: 'float' object has no attribute 'func'
If you remove sign from the original expression, there's no error, but if you replace the expression with sin(-1), it throws the same error.
I would guess it's still to do with sympy.core.numbers.NumberSymbol?
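The difference between the working and failing variants can be inspected in SymPy alone (a sketch; it says nothing about the torch side). With sign, the argument collapses to the Integer -1, and SymPy rewrites sin(-1) as -sin(1), so the remaining sin node's argument is the Integer 1, not a NumberSymbol. Without sign, sin of a Float is evaluated numerically, leaving a single Float and no sin node at all.

```python
import sympy

with_sign = sympy.sin(sympy.sign(-0.041662704))
literal = sympy.sin(-1)
without_sign = sympy.sin(-0.041662704)

# sign(negative Float) gives -1, and sin(-1) is canonicalised to -sin(1),
# so both spellings produce the same expression tree.
print(with_sign, literal, with_sign == literal)

# The sin node that remains holds an Integer argument (the singleton One).
(sin_node,) = with_sign.atoms(sympy.sin)
print(sin_node, type(sin_node.args[0]).__name__)

# A Float argument is evaluated eagerly, leaving just a Float.
print(without_sign, without_sign.is_Float)
```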
I've proposed a change in #726 to this code:
Is that feasible, or do you think it would break other behaviour?
What happened?
sympy2torch produces a module that fails when called if a function of a constant is present in the expression.
For example:
produces this error
I've tried other expressions like log(4), which produces the same problem.
The current mapping in export_torch.py is sympy.exp: torch.exp.
I believe that
then using the mapping sympy.exp: exp might work, but I have been unable to test it (adding to extra_sympy_mappings doesn't work, I think because it is chained to the end of the existing mappings and doesn't override the original one).
Alternatively, simplifying all expressions to constants where possible might solve the problem for all expressions, e.g. exp(2) would become 7.38905609893.
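The constant-folding idea can be sketched in SymPy alone: walk the tree and replace every symbol-free, non-rational subexpression with its evaluated Float, so that exp(2), E and sin(1) all become plain Floats before export. fold_constants is a hypothetical helper, not part of PySR. It assumes real-valued constants, and it deliberately leaves exact Integers and Rationals alone so that terms such as x**2 keep their exact exponent. That the exporter would then map each Float to a tensor is an assumption based on the torch.tensor(float(expr)) code linked at the top of this thread.

```python
import sympy


def fold_constants(expr):
    """Hypothetical helper: collapse constant subtrees into Floats."""
    # Symbol-free numeric subtrees (exp(2), E, sin(1), ...) are evaluated,
    # except exact Integers/Rationals, which are kept as they are.
    if expr.is_number and not expr.is_Rational:
        return expr.evalf()
    # Leaves such as Symbols (and the exact numbers skipped above) stay unchanged.
    if not expr.args:
        return expr
    # Otherwise rebuild the node from its folded children.
    return expr.func(*[fold_constants(a) for a in expr.args])


x = sympy.Symbol("x")
print(fold_constants(sympy.exp(2) + x))
print(fold_constants(sympy.sin(sympy.sign(-0.041662704)) * x**2))
```

Folding at the SymPy level keeps the fix independent of which backend (torch or JAX) is used for export, at the cost of losing exact symbolic constants in the exported module.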
Version
0.18.4
Operating System
Linux
Package Manager
pip
Interface
Script (i.e., python my_script.py)
Relevant log output
No response
Extra Info
No response