MilesCranmer / PySR

High-Performance Symbolic Regression in Python and Julia
https://astroautomata.com/PySR
Apache License 2.0

[Feature] add formula prior #285

Closed GCaptainNemo closed 1 year ago

GCaptainNemo commented 1 year ago

For given input variables x1, x2, x3, x4 and label y, I have a prior that the target formula has the form y = f1(x1, x4) + f2(x2, x3); that is, the input variables are grouped, and the formula has a tree-form constraint. Could PySR support this prior? Thanks.
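For concreteness, here is a small synthetic dataset of that grouped form. The particular functions and coefficients are hypothetical, chosen only to illustrate the y = f1(x1, x4) + f2(x2, x3) structure:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(100, 4))
x1, x2, x3, x4 = X.T

# Hypothetical ground truth with the grouped structure
# y = f1(x1, x4) + f2(x2, x3):
y = 2.0 * x1 * x4 + np.cos(x2) + x3
```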

MilesCranmer commented 1 year ago

Check out https://github.com/MilesCranmer/PySR/pull/276 - I think this supports the use case?

MilesCranmer commented 1 year ago

For posterity, here's how you would do it:

objective = """
function my_custom_objective(tree, dataset::Dataset{T}, options) where {T<:Real}
    # Require root node to be binary, so we can split it,
    # otherwise return a large loss:
    tree.degree != 2 && return T(10000)

    f1 = tree.l
    f2 = tree.r

    # Evaluate f1:
    f1_value, flag = eval_tree_array(f1, dataset.X, options)
    !flag && return T(10000)

    # Evaluate f2:
    f2_value, flag = eval_tree_array(f2, dataset.X, options)
    !flag && return T(10000)

    # Impose functional form:
    prediction = f1_value .+ f2_value

    # Check whether x2 or x3 appears in an expression:
    function contains_x2_x3(t)
        if t.degree == 0
            return !t.constant && t.feature in (2, 3)
        elseif t.degree == 1
            return contains_x2_x3(t.l)
        else
            return contains_x2_x3(t.l) || contains_x2_x3(t.r)
        end
    end

    # Check whether x1 or x4 appears in an expression:
    function contains_x1_x4(t)
        if t.degree == 0
            return !t.constant && t.feature in (1, 4)
        elseif t.degree == 1
            return contains_x1_x4(t.l)
        else
            return contains_x1_x4(t.l) || contains_x1_x4(t.r)
        end
    end

    f1_violating = contains_x2_x3(f1)
    f2_violating = contains_x1_x4(f2)

    regularization = T(100) * f1_violating + T(100) * f2_violating

    prediction_loss = sum((prediction .- dataset.y) .^ 2) / dataset.n
    return prediction_loss + regularization
end
"""

model = PySRRegressor(
    binary_operators=["*", "+", "-"],
    full_objective=objective
)
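To see what the feature-containment recursion in the objective is doing, here is a plain-Python mirror of it. The `Node` class below is a hypothetical stand-in for SymbolicRegression.jl's internal tree type, not PySR's actual API; as in the Julia code, degree 0 is a leaf, 1 a unary operator, and 2 a binary operator:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    """Hypothetical stand-in for the Julia expression-tree node."""
    degree: int
    constant: bool = False
    feature: int = 0           # 1-based feature index, as in Julia
    l: Optional["Node"] = None
    r: Optional["Node"] = None

def contains_features(t: Node, features: set) -> bool:
    """Mirror of the Julia contains_x2_x3 / contains_x1_x4 recursion."""
    if t.degree == 0:
        # Leaf: a feature node (not a constant) violating iff its
        # feature index is in the forbidden set.
        return (not t.constant) and t.feature in features
    elif t.degree == 1:
        return contains_features(t.l, features)
    else:
        return contains_features(t.l, features) or contains_features(t.r, features)

# Example: the tree for x1 * x4.
tree = Node(degree=2,
            l=Node(degree=0, feature=1),
            r=Node(degree=0, feature=4))
print(contains_features(tree, {2, 3}))  # False: uses only x1 and x4
print(contains_features(tree, {1, 4}))  # True
```

With this check in hand, the objective simply adds a large constant penalty whenever the left subtree touches x2/x3 or the right subtree touches x1/x4, which steers the search toward the grouped form without forbidding exploration outright.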

It won't completely constrain the expression to that form, because it can be useful for the genetic algorithm to explore violating expressions. However, the final expressions should end up in the desired form, because the regularization penalizes violations.

Note, however, that the returned expression and printed format will not have the form you specified; you will have to parse them into that form manually. In the future I could perhaps look at adding this, but it will be a bit tricky to write generally.
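One way to do that manual parsing, assuming the fitted equation's top-level operation is an addition (which the objective encourages but does not guarantee), is to split the sympy form of the expression by which variables each term uses. The expression below is a made-up example standing in for what `model.sympy()` might return:

```python
import sympy as sp

x1, x2, x3, x4 = sp.symbols("x1 x2 x3 x4")

# Hypothetical fitted expression, standing in for model.sympy():
expr = 2 * x1 * x4 + sp.cos(x2) + x3

# Partition the top-level sum into the two groups:
terms = sp.Add.make_args(expr)
f1 = sum(t for t in terms if t.free_symbols <= {x1, x4})
f2 = sum(t for t in terms if t.free_symbols <= {x2, x3})

print(f1)  # 2*x1*x4
print(f2)  # x3 + cos(x2)
```

If any top-level term mixes variables from both groups, it will fall into neither bucket, which is itself a useful signal that the regularization did not fully take hold for that equation.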

Also, perhaps there's a way to make this easier in the future, e.g., letting you write out your desired functional form directly; I will have to think about how to implement that. For now, just write a custom objective like this.

GCaptainNemo commented 1 year ago

Thank you very much! The code and explanation are very clear!

MilesCranmer commented 1 year ago

Awesome. Also note that I just pushed a quick fix to that PR. If you tried it earlier and it didn't work, trying again should hopefully work now.

In the future I'll make it so you can specify a custom function for printing too.

MilesCranmer commented 1 year ago

Let me know if there are any other issues. Cheers, Miles