heal-research / pyoperon

Python bindings and scikit-learn interface for the Operon library for symbolic regression.
MIT License

How to save the trained model #21

Open luokuang2001 opened 2 months ago

luokuang2001 commented 2 months ago

I have trained a symbolic regression model; it has the following form: [image]

How can I save the trained model so that next time I can use it directly without retraining? Something like "torch.save" in PyTorch, which saves a ".pth" file.

foolnotion commented 2 months ago

Hi, you have two options:

  1. Save the expression string and parse it back into a tree later
  2. Save the tree model directly using pickle
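For option 1 you don't need anything beyond the Python standard library. Below is a minimal sketch, assuming you also want to record which dataset column each X_i refers to. The helper names save_expression/load_expression and the JSON layout are hypothetical (they are not part of pyoperon), and the expression is only a toy.

```python
import json
import tempfile
from pathlib import Path


def save_expression(path, expr, feature_names, target=None):
    """Persist a model as its infix string plus the column names it refers to."""
    payload = {
        "expression": expr,                    # infix string, as printed by the formatter
        "feature_names": list(feature_names),  # X1 -> feature_names[0], X2 -> feature_names[1], ...
        "target": target,                      # optional, for documentation only
    }
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create e.g. 'models/' if missing
    path.write_text(json.dumps(payload, indent=2))


def load_expression(path):
    """Read back the dictionary written by save_expression."""
    return json.loads(Path(path).read_text())


# usage: write to a temporary directory and read the model back
with tempfile.TemporaryDirectory() as tmp:
    model_path = Path(tmp) / "models" / "model.json"
    save_expression(model_path, "(1.5 * X1) + X2", ["X1", "X2"], target="Y")
    model = load_expression(model_path)

print(model["expression"])  # (1.5 * X1) + X2
```

The loaded string can then be turned back into a tree with the parser, as shown below.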

Here is some example code illustrating both options, using the pyoperon bindings to the Operon library.

import pyoperon as op # the operon bindings
import os
import re
import pickle

Load some data and get the dataset variables (identified by their hashes):

dataset = op.Dataset('./data/Poly-10.csv', True)

variable_hashes = [v.Hash for v in dataset.Variables]
variable_hashes

Now let's say you have an infix expression in string form. First, we extract the variable names from the expression (these will always be named $X_i$, since pyoperon doesn't need a header). Then we map every $X_i$ to the corresponding variable from the dataset. Note that $X_i$ is 1-based, so $X_1$ is the first variable.

expr = '((-0.000030386898288270458579) + ((-0.400405615568161010742188) * (((2.294569730758666992187500 * X5) * ((-1.088928818702697753906250) * X6)) - (((((-0.937077045440673828125000) * X7) * (((-2.449583292007446289062500) * X9) * (0.209568977355957031250000 * X1))) + (((((-0.039391547441482543945312) * X4) * ((-0.039391547441482543945312) * X4)) * (((2.302585124969482421875000 * X2) * (1.030569076538085937500000 * X1)) - ((((-0.039391547441482543945312) * X4) * (1.036243557929992675781250 * X1)) + ((2.718281745910644531250000 * X3) * (1.777204394340515136718750 * X4))))) + ((((-0.929394364356994628906250) * X7) * (((-2.058582305908203125000000) * X9) * (1.052866220474243164062500 * X1))) + ((2.407963514328002929687500 * X2) * (1.036243557929992675781250 * X1))))) + ((1.414213538169860839843750 * X3) * ((1.770285606384277343750000 * X4) - (((-1.811097383499145507812500) * X10) * (0.974092006683349609375000 * X6))))))))'

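m = re.findall(r'X\d+', expr)  # every occurrence of X<number>, with repeats

variables = {}

for v in dict.fromkeys(m):  # unique names, in order of first appearance
    i = int(v[1:])  # 'X10' -> 10
    variables[v] = variable_hashes[i - 1]  # X_i is 1-based, the variable list is 0-based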

variables
{'X5': 16075665569742270374,
 'X6': 9134146818458426180,
 'X7': 18044635619207560834,
 'X9': 2652961248133790663,
 'X1': 4295753595843180382,
 'X4': 17733306235974623085,
 'X2': 18188060951833565637,
 'X3': 4397419642548150523,
 'X10': 17424446509373167524}
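The loop above relies on $X_i$ being 1-based. One subtle consequence: a stray X0 would index variable_hashes[-1] and silently pick the last variable. A slightly more defensive sketch of the same mapping as a reusable function is shown below; map_variables is a hypothetical helper name, and the hashes in the usage line are made up (in practice you would pass variable_hashes from the dataset).

```python
import re


def map_variables(expr, variable_hashes):
    """Map each X_i in an infix string to the hash of the i-th dataset variable (1-based)."""
    mapping = {}
    for name in dict.fromkeys(re.findall(r"X\d+", expr)):  # unique names, first-seen order
        i = int(name[1:])  # 'X10' -> 10
        if not 1 <= i <= len(variable_hashes):
            # without this check, X0 would silently map to variable_hashes[-1]
            raise ValueError(f"{name} has no matching dataset variable")
        mapping[name] = variable_hashes[i - 1]
    return mapping


# usage with made-up hashes for a 3-variable dataset
hashes = [111, 222, 333]
print(map_variables("(2.0 * X3) + (X1 * X3)", hashes))  # {'X3': 333, 'X1': 111}
```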

Now that we have the correct variable mapping, we can parse the expression:

tree = op.InfixParser.Parse(expr, variables)

# print it out again
decimal_precision = 3 # how many decimals to use when formatting floating point values
print(op.InfixFormatter.Format(tree, dataset, decimal_precision))
((-0.000) + ((-0.400) * (((2.295 * (1.000 * X5)) * ((-1.089) * (1.000 * X6))) - (((((-0.937) * (1.000 * X7)) * (((-2.450) * (1.000 * X9)) * (0.210 * (1.000 * X1)))) + (((((-0.039) * (1.000 * X4)) * ((-0.039) * (1.000 * X4))) * (((2.303 * (1.000 * X2)) * (1.031 * (1.000 * X1))) - ((((-0.039) * (1.000 * X4)) * (1.036 * (1.000 * X1))) + ((2.718 * (1.000 * X3)) * (1.777 * (1.000 * X4)))))) + ((((-0.929) * (1.000 * X7)) * (((-2.059) * (1.000 * X9)) * (1.053 * (1.000 * X1)))) + ((2.408 * (1.000 * X2)) * (1.036 * (1.000 * X1)))))) + ((1.414 * (1.000 * X3)) * ((1.770 * (1.000 * X4)) - (((-1.811) * (1.000 * X10)) * (0.974 * (1.000 * X6)))))))))

We can evaluate the parsed tree using op.Evaluate:

values = op.Evaluate(tree, dataset, op.Range(0, 10))
values
array([ 0.4543666 ,  0.27158856, -0.11406795, -0.4064015 , -0.10081271,
        0.17754017, -1.0105664 ,  0.4164615 ,  0.44278234,  0.0433833 ],
      dtype=float32)

The tree can be pickled:

os.makedirs('pickled', exist_ok=True)  # make sure the target directory exists
path = os.path.join('pickled', 'tree.pkl')

with open(path, 'wb') as f:
    pickle.dump(tree, f)

Then it can be loaded again:

with open(path, 'rb') as f:
    tree_unpickled = pickle.load(f)
    print(op.InfixFormatter.Format(tree_unpickled, dataset, decimal_precision))
((-0.000) + ((-0.400) * (((2.295 * (1.000 * X5)) * ((-1.089) * (1.000 * X6))) - (((((-0.937) * (1.000 * X7)) * (((-2.450) * (1.000 * X9)) * (0.210 * (1.000 * X1)))) + (((((-0.039) * (1.000 * X4)) * ((-0.039) * (1.000 * X4))) * (((2.303 * (1.000 * X2)) * (1.031 * (1.000 * X1))) - ((((-0.039) * (1.000 * X4)) * (1.036 * (1.000 * X1))) + ((2.718 * (1.000 * X3)) * (1.777 * (1.000 * X4)))))) + ((((-0.929) * (1.000 * X7)) * (((-2.059) * (1.000 * X9)) * (1.053 * (1.000 * X1)))) + ((2.408 * (1.000 * X2)) * (1.036 * (1.000 * X1)))))) + ((1.414 * (1.000 * X3)) * ((1.770 * (1.000 * X4)) - (((-1.811) * (1.000 * X10)) * (0.974 * (1.000 * X6)))))))))

Check that the unpickled tree returns the same values:

values = op.Evaluate(tree_unpickled, dataset, op.Range(0, 10))
values
array([ 0.4543666 ,  0.27158856, -0.11406795, -0.4064015 , -0.10081271,
        0.17754017, -1.0105664 ,  0.4164615 ,  0.44278234,  0.0433833 ],
      dtype=float32)
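If you want something closer to torch.save/torch.load, the pickle steps can be wrapped in two small helpers. This is a minimal sketch: save_pickle/load_pickle are hypothetical names, and a plain dict stands in for the tree so the snippet runs without pyoperon. Keep in mind that pickle.load can execute arbitrary code, so only load files you trust.

```python
import pickle
import tempfile
from pathlib import Path


def save_pickle(obj, path):
    """Pickle obj to path (like torch.save), creating parent directories as needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(path):
    """Load a pickled object (like torch.load). Only unpickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


# usage: a plain dict stands in for the tree so this runs without pyoperon
model = {"expression": "(1.5 * X1) + X2"}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "pickled" / "model.pkl"
    save_pickle(model, path)
    restored = load_pickle(path)

print(restored == model)  # True
# with pyoperon, compare op.Evaluate(...) on the original and restored trees instead,
# e.g. with numpy.array_equal
```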

If you don't want to depend on the bindings, you can also parse and evaluate the expression string with SymPy:

from sympy import lambdify, parse_expr, Symbol
import pandas as pd

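def evaluate_expression(sexpr, data):
    # assumes the columns are named like the variables in the expression (X1, X2, ...)
    # and that the last column is the target, which is skipped
    symbols = [Symbol(x) for x in data.columns.values[:-1]]
    # lambdify turns the expression into a NumPy function; pass one array per column
    return lambdify(symbols, parse_expr(sexpr))(*data.values[:, :-1].T)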

# read the data as a pandas dataframe
df = pd.read_csv('./data/Poly-10.csv')

sexpr = expr.replace('^', '**')  # parse_expr reads ^ as XOR by default, so convert it to a power
values = evaluate_expression(sexpr, df.iloc[0:10])
values
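If even SymPy is not available, a purely arithmetic expression (only +, -, *, / and ^, as in the example above) can be evaluated with Python's standard ast module. This is a swapped-in technique, not something pyoperon provides: evaluate_row is a hypothetical helper, it handles one sample at a time, and it rejects function calls such as exp or sin, which you would have to add yourself if your models use them. Walking the syntax tree instead of calling eval() means arbitrary code in a saved string cannot run.

```python
import ast
import operator

# operators the evaluator accepts; any other syntax is rejected
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.USub: operator.neg, ast.UAdd: operator.pos}


def evaluate_row(expr, row):
    """Evaluate an infix arithmetic string for one sample; row maps 'X1', 'X2', ... to floats."""
    tree = ast.parse(expr.replace("^", "**"), mode="eval")

    def visit(node):
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return float(node.value)
        if isinstance(node, ast.Name):
            return row[node.id]  # KeyError if the variable is missing from row
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](visit(node.left), visit(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](visit(node.operand))
        raise ValueError(f"unsupported syntax: {ast.dump(node)}")

    return visit(tree)


print(evaluate_row("((-0.5) + (2.0 * X1)) * X2", {"X1": 3.0, "X2": 2.0}))  # 11.0
```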

Hope this helps; feel free to ask for details.