I just wanted to suggest the following feature, which I've implemented for myself:
The export_torch.SingleSymPyModule object only holds the initial expression string at the time of export from a PySR model (Sympy expression).
I'm considering the use case in which the PyTorch model would be further trained to tweak parameters, and it would be very helpful to then inspect the resulting expression, with updated parameter values. However, there's currently no method to update the string.
If this is interesting (either here or in the base sympytorch project @patrick-kidger ), I'm pasting my implementation below, which returns a Sympy expression that represents the current SingleSymPyModule (as an independent function, but could be a proper class __repr__).
It's based on reverse engineering the recursive Node structure so I'm not entirely sure of my parsing (specifically, the conditions for deciding the type of node). The way I'm mapping back function names to sympy.core also feels a bit too broad (using all public content of that module), but working under a time budget :(
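On the "too broad" mapping: a narrower lookup is easy to sketch. This is only an illustration. `ALLOWED_OPS` and `parse_whitelisted` are hypothetical names, and the operator list is an assumption that would need to match whatever is passed to PySRRegressor. As far as I can tell from the SymPy docs, `parse_expr` also resolves names through a default global namespace built with `from sympy import *`, so a small `local_dict` does not restrict anything by itself; the sketch therefore checks the function names explicitly before parsing.

```python
import re

import sympy
from sympy.parsing.sympy_parser import parse_expr

# Hypothetical whitelist: extend it to match the operators given to PySRRegressor.
ALLOWED_OPS = {name: getattr(sympy, name)
               for name in ("Add", "Mul", "Pow", "sin", "cos", "exp", "log", "Abs")}


def parse_whitelisted(str_repr: str) -> sympy.Expr:
    """Parse a function-call style string, rejecting unknown function names."""
    # Names directly followed by "(" are the function calls in the string.
    called = set(re.findall(r"[A-Za-z_]\w*(?=\()", str_repr))
    unknown = called - set(ALLOWED_OPS)
    if unknown:
        raise ValueError(f"Unexpected operations: {sorted(unknown)}")
    return parse_expr(str_repr, local_dict=ALLOWED_OPS)


# Made-up input in the same style as the strings built from the node tree.
expr = parse_whitelisted("Add(Mul(2.0, cos(x0)), Pow(x0, 2))")
```

Anything outside the whitelist, e.g. `parse_whitelisted("Foo(x0)")`, raises a `ValueError` instead of being silently turned into an undefined function.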
I did, however, check it with property-based testing (using hypothesis), making sure that the round trip PySR -> SingleSymPyModule -> SymPy expression always agrees on output with the initial PySR model, given random choices of input operators for PySRRegressor. Seems to work fine. Pasting that below as well.
The only thing that's not guaranteed to match is the order of terms. That's fine for my work, and I'm not sure how much time fixing it would take, but it would definitely help to have matching term order as well.
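On the term-order point: assuming the mismatch is only in how the terms are ordered, the two expressions can still be compared as SymPy objects rather than as strings, because `Add` and `Mul` keep their arguments in a canonical order. A small illustration with made-up strings (not actual PySR output):

```python
import sympy
from sympy.parsing.sympy_parser import parse_expr

# Same expression, written with the terms and factors in a different order.
a = parse_expr("Add(Mul(2.0, cos(x0)), Pow(x0, 2))")
b = parse_expr("Add(Pow(x0, 2), Mul(cos(x0), 2.0))")

# Structural comparison ignores the order in which the arguments were given.
same = (a == b)

# A symbolic difference is a second, more forgiving check.
difference = sympy.simplify(a - b)
```

So an order mismatch in the printed string does not by itself mean the round trip changed the expression.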
Suggestion
```python
import sympy
import torch
from pysr import export_torch  # assumed import path for export_torch
from sympy.parsing.sympy_parser import parse_expr


def sympytorch_expr(model: export_torch.SingleSymPyModule) -> sympy.Expr:
    """
    Retrieve the Sympy expression of a SingleSymPyModule.
    Relies on mapping of Sympy operations in sympy.core.__dict__,
    e.g. {'Mul': sympy.core.mul.Mul}
    :param model: SingleSymPyModule instance
    :return: Sympy expression
    """
    str_repr = _sympytorch_node_repr(model._node)
    sympy_op_mapping = sympy.core.__dict__
    # A modicum of sanitizing: keep only the public names
    sympy_op_mapping = {op_name: op for op_name, op in sympy_op_mapping.items()
                        if not op_name.startswith('_')}
    return parse_expr(str_repr, local_dict=sympy_op_mapping)


def _sympytorch_node_repr(node) -> str:
    if _sympytorch_node_is_variable(node):
        return node._name
    if _sympytorch_node_is_function(node):
        return str(node)
    if _sympytorch_node_is_parameter(node):
        # Trainable constants are torch Parameters; read their current value
        if isinstance(node._value, torch.nn.Parameter):
            return str(node._value.data.item())
        return str(node._value)
    else:
        # Remove the qualifier from class name for later parsing from sympy.core
        # e.g. sympy.core.mul.Mul -> Mul
        func_repr = str(node._sympy_func).split('.')[-1].split("'")[0]
        args_repr = [_sympytorch_node_repr(arg) for arg in node._args]
        args_repr = '(' + ', '.join(args_repr) + ')'
        args_repr = func_repr + args_repr
        return args_repr


def _sympytorch_node_is_variable(node) -> bool:
    return hasattr(node, '_name')


def _sympytorch_node_is_function(node) -> bool:
    return isinstance(node, sympy.core.function.FunctionClass)


def _sympytorch_node_is_parameter(node) -> bool:
    return not hasattr(node, '_args') or not node._args
```
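A possible alternative that skips the string round trip and the name mapping altogether: since each node already stores its SymPy class in `_sympy_func`, the expression could be rebuilt by calling that class on the recursively rebuilt arguments. The sketch below is untested against real sympytorch nodes. `FakeNode` is a hypothetical stand-in that only mimics the `_name`, `_value`, `_args` and `_sympy_func` attributes used above, and `node_to_expr` is a made-up name.

```python
import sympy


class FakeNode:
    """Hypothetical stand-in for a sympytorch node (not the real class)."""

    def __init__(self, sympy_func, args=(), value=None, name=None):
        self._sympy_func = sympy_func
        self._args = args
        if value is not None:
            self._value = value
        if name is not None:
            self._name = name


def node_to_expr(node) -> sympy.Expr:
    # Variable: checked first, mirroring the order of the helpers above.
    if hasattr(node, "_name"):
        return sympy.Symbol(node._name)
    # Constant: no child arguments, only a stored value.
    if not getattr(node, "_args", None):
        value = node._value
        # Assumption: tensor-like values (e.g. a torch Parameter) expose .item()
        if hasattr(value, "item"):
            value = value.item()
        return sympy.sympify(value)
    # Operation: call the stored SymPy class on the rebuilt children.
    return node._sympy_func(*(node_to_expr(arg) for arg in node._args))


# Made-up tree for 2.5*cos(x0) + x0**2
x0 = FakeNode(sympy.Symbol, name="x0")
tree = FakeNode(sympy.Add, args=(
    FakeNode(sympy.Mul, args=(FakeNode(sympy.Float, value=2.5),
                              FakeNode(sympy.cos, args=(x0,)))),
    FakeNode(sympy.Pow, args=(x0, FakeNode(sympy.Integer, value=2))),
))
expr = node_to_expr(tree)
```

The trade-off is that this leans on `_sympy_func` being directly callable for every node type, which I have not verified for custom operators.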
Property-based test
```python
import numpy as np
from pysr import PySRRegressor
from sympy import lambdify
from hypothesis import given, seed, settings, HealthCheck
import hypothesis.strategies as strat
import pytest

# Assumes sympytorch_expr (from the suggestion above) is importable in this module.


@pytest.fixture
def data_for_test_sympytorch_repr():
    """Cache the data generation to save a little time per example."""
    rng = np.random.default_rng(42)
    X = rng.uniform(low=0, high=2, size=(10, 1))
    X_test = rng.uniform(low=2, high=4, size=(10, 1))
    y = 2 * np.cos(X) + X ** 2 - 2
    return X, X_test, y


@given(
    binary_operators=strat.sets(strat.sampled_from(['+', '*', '/']),
                                min_size=1, max_size=2),
    unary_operators=strat.sets(strat.sampled_from(['sin', 'log', 'sqrt', 'square']),
                               min_size=1, max_size=2),
)
@settings(max_examples=100,
          deadline=10000,  # Account for variation between example times
          suppress_health_check=[HealthCheck.function_scoped_fixture])
@seed(42)
def test_sympytorch_repr(binary_operators, unary_operators, data_for_test_sympytorch_repr):
    """
    Test invariant = outputs match after round-trip
    PySR model (expression) -> SingleSymPyModule -> SymPy expression
    under random PySR input operators.
    """
    X, X_test, y = data_for_test_sympytorch_repr
    model = PySRRegressor(binary_operators=list(binary_operators),
                          unary_operators=list(unary_operators),
                          niterations=2,
                          deterministic=True,
                          procs=0,
                          random_state=42,
                          temp_equation_file=True,
                          verbosity=0)
    model.fit(X, y)
    torch_model = model.pytorch()
    torch_model_expr = sympytorch_expr(torch_model)
    # Evaluate the recovered expression numerically on held-out inputs
    torch_model_expr_func = lambdify('x0', torch_model_expr, 'numpy')
    output_expr = torch_model_expr_func(X_test).ravel()
    output_pysr = model.predict(X_test).ravel()
    assert np.allclose(output_expr, output_pysr)
```
Feature Request
Hi!
Very cool project!