sympy / sympy

A computer algebra system written in pure Python
https://sympy.org/

solve() infinite loop + how to check if an equation is solveable? #25920

Open rijkdw opened 11 months ago

rijkdw commented 11 months ago

The formula for degrees of freedom for the Welch t-test is given below.

$$ \large \text{df} = \frac{ \left(\frac{v_1}{n_1} +\frac{v_2}{n_2}\right)^2 }{ \frac{v_1^2}{n_1^2(n_1-1)} + \frac{v_2^2}{n_2^2(n_2-1)} } $$

In the past, sympy could not solve this equation for any of the right-hand-side variables: for all four variables, sympy 1.1.1 raised a NotImplementedError. Now (using 1.12) it can solve for $v_1$ and $v_2$:

>>> from sympy import Eq, solve, sympify
>>> eq = Eq(
...     sympify("df"),
...     sympify("(v1/n1+v2/n2)^2/(pow(v1,2)/(pow(n1,2)*(n1-1))+pow(v2,2)/(pow(n2,2)*(n2-1)))"),
... )

>>> solve(eq, "v1")
[0]: (the latex output)

>>> str(_[0])  # the first solution
[1]: 'n1*v2*(-sqrt(-df*(n1 - 1)*(n2 - 1)*(df - n1 - n2 + 2))*(df - n1 + 1) + (n1 - 1)*(df*n2 - df - n1*n2 + n1 + n2 - 1))/(n2*(df - n1 + 1)*(df*n2 - df - n1*n2 + n1 + n2 - 1))'

>>> # similar for "v2"

Coming to the main point of this issue: solve()-ing for $n_1$ or $n_2$ no longer raises a NotImplementedError, but also does not return a result in any reasonable amount of time (>10 minutes).

The degrees-of-freedom equation serves well as an example, but I am looking for a more general approach to the underlying problem: sympy wrongly concludes that it can solve an equation and lets solve() run forever.

It is desirable that either

  1. I can check if an equation is solveable algebraically for a given symbol; or
  2. sympy can detect the "unsolveability" and still throw a NotImplementedError.

If there is a general solution to either of the above points (so that it can be applied to any future equation), I'd be grateful to hear it. Thank you in advance.

-- Rijk

oscarbenjamin commented 11 months ago

This is slow in checking. You can get the result quickly if you call with check=False:

In [14]: df, v1, v2, n1, n2 = symbols("df, v1, v2, n1, n2")

In [15]: eq = Eq(df, (v2/n2 + v1/n1)**2/(v2**2/(n2**2*(n2 - 1)) + v1**2/(n1**2*(n1 - 1))))

In [16]: sols = solve(eq, n1, check=False)

In [17]: print(sols[0])
0

In [18]: print(sols[1])
-(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)/(3*(sqrt(-4*(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)**3 + (27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(-df*v2**2 + n2*v2**2 - v2**2) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/((-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + 2*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**2)/2 + 27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(2*(-df*v2**2 + n2*v2**2 - v2**2)) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/(2*(-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**(1/3)) - (sqrt(-4*(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)**3 + (27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(-df*v2**2 + n2*v2**2 - v2**2) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/((-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + 2*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**2)/2 + 27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(2*(-df*v2**2 + n2*v2**2 - v2**2)) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/(2*(-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**(1/3)/3 - (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/(3*(-df*v2 + n2*v2 - v2))

In [19]: print(sols[2])
-(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)/(3*(-1/2 - sqrt(3)*I/2)*(sqrt(-4*(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)**3 + (27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(-df*v2**2 + n2*v2**2 - v2**2) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/((-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + 2*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**2)/2 + 27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(2*(-df*v2**2 + n2*v2**2 - v2**2)) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/(2*(-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**(1/3)) - (-1/2 - sqrt(3)*I/2)*(sqrt(-4*(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)**3 + (27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(-df*v2**2 + n2*v2**2 - v2**2) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/((-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + 2*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**2)/2 + 27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(2*(-df*v2**2 + n2*v2**2 - v2**2)) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/(2*(-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**(1/3)/3 - (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + 
v2)/(3*(-df*v2 + n2*v2 - v2))

In [20]: print(sols[3])
-(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)/(3*(-1/2 + sqrt(3)*I/2)*(sqrt(-4*(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)**3 + (27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(-df*v2**2 + n2*v2**2 - v2**2) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/((-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + 2*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**2)/2 + 27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(2*(-df*v2**2 + n2*v2**2 - v2**2)) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/(2*(-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**(1/3)) - (-1/2 + sqrt(3)*I/2)*(sqrt(-4*(-3*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)/(-df*v2**2 + n2*v2**2 - v2**2) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**2/(-df*v2 + n2*v2 - v2)**2)**3 + (27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(-df*v2**2 + n2*v2**2 - v2**2) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/((-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + 2*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**2)/2 + 27*(-df*n2**3*v1**2 + df*n2**2*v1**2 - n2**3*v1**2 + n2**2*v1**2)/(2*(-df*v2**2 + n2*v2**2 - v2**2)) - 9*(n2**3*v1**2 - n2**2*v1**2 - 2*n2**2*v1*v2 + 2*n2*v1*v2)*(df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)/(2*(-df*v2 + n2*v2 - v2)*(-df*v2**2 + n2*v2**2 - v2**2)) + (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + v2)**3/(-df*v2 + n2*v2 - v2)**3)**(1/3)/3 - (df*v2 + 2*n2**2*v1 - 2*n2*v1 - n2*v2 + 
v2)/(3*(-df*v2 + n2*v2 - v2))

The first of these 4 solutions (n1 = 0) is invalid and would be rejected by checking. Checking the other conditions is slow though.
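
To illustrate what `check` does on a small case, the snippet below uses the `sin(x)/x` example from the `solve` docstring, where `x = 0` zeroes the denominator in much the same way as the `n1 = 0` candidate above. The helper `drop_denominator_zeros` is a hypothetical name introduced here for illustration; it covers only the denominator test, not the full checking that `solve` performs.

```python
from sympy import fraction, simplify, sin, solve, symbols

x = symbols("x")
expr = sin(x) / x

checked = solve(expr, x)                 # default: candidates are verified
unchecked = solve(expr, x, check=False)  # verification skipped, so x = 0 survives


def drop_denominator_zeros(expr, symbol, candidates):
    """Keep only the candidates that do not zero the denominator of expr.

    Hypothetical helper: a cheap stand-in for part of what check=True does.
    """
    denom = fraction(expr)[1]
    return [c for c in candidates if simplify(denom.subs(symbol, c)) != 0]


filtered = drop_denominator_zeros(expr, x, unchecked)
print("checked:  ", checked)
print("unchecked:", unchecked)
print("filtered: ", filtered)
```

The denominator test is cheap here because the denominator is just `x`; for the degrees-of-freedom equation the same idea applies to the factors of `n1` in the denominators.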

CC @smichr this is what I meant in https://github.com/sympy/sympy/pull/25912#issuecomment-1819807361. Checking by substituting into the equation and then trying to simplify is a mistake. We can easily check here whether the solution is reduced to zero or not when solving a polynomial system, so if we include the denominators in the equations then we should not need to do complicated checking.
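
A rough sketch of the "clear the denominators first" idea, assuming that `cancel` followed by `fraction` is an acceptable stand-in for that step (this is not a description of what `solve` does internally). It shows that the equation reduces to a cubic polynomial in `n1`, and then sanity-checks the cubic numerically. The sample values for `v1`, `v2`, `n1` and `n2` are made up for the check.

```python
from sympy import Poly, cancel, fraction, symbols

df, v1, v2, n1, n2 = symbols("df v1 v2 n1 n2")

rhs = (v1/n1 + v2/n2)**2 / (v1**2/(n1**2*(n1 - 1)) + v2**2/(n2**2*(n2 - 1)))

# Put df - rhs over a single denominator and cancel common factors.
numer, denom = fraction(cancel(df - rhs))

# The numerator is a polynomial in n1; its roots are the solution candidates.
degree = Poly(numer, n1).degree()
print("degree in n1:", degree)

# Numeric sanity check with made-up sample values (assumptions, not from the thread).
sample = {v1: 4, v2: 9, n2: 15}
true_n1 = 10
df_value = rhs.subs(sample).subs(n1, true_n1)  # exact rational df for these inputs

numeric_poly = Poly(numer.subs(sample).subs(df, df_value), n1)
roots = numeric_poly.nroots()

# The n1 used to generate df should reappear among the numeric roots ...
recovered = any(abs(complex(r) - true_n1) < 1e-8 for r in roots)
# ... and it must not zero the cleared denominator.
denom_at_root = denom.subs(sample).subs({df: df_value, n1: true_n1})
print("recovered:", recovered, "denominator nonzero:", denom_at_root != 0)
```

Working on the cleared numerator means a candidate root only has to be tested against the denominator polynomial, rather than substituted into the original nested fraction and simplified.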

rijkdw commented 11 months ago

Hi @oscarbenjamin, thank you for that, this equation does indeed solve almost instantly with check=False. However, it is my understanding that the check flag is not a universal solution to my more general problem, which is "checking if an equation CAN be solved before solving is initiated".

For example, in a similar issue (#25792), adding check=False to solve() doesn't change the behaviour:

import sympy
from sympy import Symbol, solve

def sympify(expression: str):
    """Convert the given expression string into a sympy expression.
    Float literals are converted to exact rationals (`rational=True`)."""
    return sympy.sympify(
        expression, rational=True,
    )

def _convert_equation_base(base_equation: str):
    equation_parts = base_equation.split("=")
    left, right = equation_parts
    eq = sympy.Eq(sympify(left), sympify(right))

    # The names of the unique variables in this equation
    variables = set([symbol.name for symbol in eq.atoms(Symbol)])

    for variable in variables:
        # Try and solve the equation for this variable
        try:
            # solutions in terms of `variable`
            solved_list = solve(eq, variable, simplify=True, check=False)
            print(solved_list)
        except Exception as e:
            print(e)

print(_convert_equation_base(
    "G = exp(((Etwo - Eone) / (1.3806503 * (pow(10, -23)) * T)))"))   # changed to exp() as suggested by smichr in that issue

# never terminates

Is there any way to achieve either of the two points in this issue's description?

Thank you in advance -- Rijk

oscarbenjamin commented 11 months ago

> Is there any way to achieve either of the two points in this issue's description?

The points are not well defined:

> 1. I can check if an equation is solveable algebraically for a given symbol; or
> 2. sympy can detect the "unsolveability" and still throw a NotImplementedError.

In all of the examples shown, SymPy has an algorithm that can solve the equation; it just happens to be slow. So the equation is solvable algebraically, and solve will try to solve it but will take a long time.
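
Since the difficulty is run time rather than solvability, one pragmatic workaround is to cap how long a call is allowed to run. The sketch below is not a SymPy feature: `call_with_timeout` and `SolveTimeout` are hypothetical names, and the approach assumes a Unix-like system with the call made from the main thread, because it relies on `signal.SIGALRM`.

```python
import signal

from sympy import Eq, solve, symbols


class SolveTimeout(Exception):
    """Raised when the wrapped call exceeds its time budget (hypothetical helper)."""


def call_with_timeout(func, seconds, *args, **kwargs):
    """Run func(*args, **kwargs), raising SolveTimeout after `seconds` seconds.

    Assumption: Unix only, main thread only (uses SIGALRM).
    """
    def _on_alarm(signum, frame):
        raise SolveTimeout(f"gave up after {seconds} s")

    previous = signal.signal(signal.SIGALRM, _on_alarm)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        return func(*args, **kwargs)
    finally:
        # Always cancel the timer and restore the previous handler.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous)


x = symbols("x")
quick = call_with_timeout(solve, 10, Eq(x**2, 4), x)
print(quick)
```

A caller can catch `SolveTimeout` and treat it like the NotImplementedError case. This only bounds the waiting time; it says nothing about whether a longer run would have succeeded.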

rijkdw commented 11 months ago

That's fair, thank you for the concise explanation. I'm happy to close this issue, though I see you've tagged smichr in your initial response, so I'll leave to you the final decision to close or not to close depending on that conversation. Thanks again for your help!

oscarbenjamin commented 11 months ago

I think there is an issue here, namely that these examples are slow. That should be fixed, but it requires changing the way that solve works. Currently solve checks by substituting the solutions into the equations and any denominators, which can be extremely slow. It would be better to handle the checking in a different way, as mentioned in the comment I linked above.

smichr commented 11 months ago

> checking if an equation CAN be solved before solving is initiated

This seems like a pretty high bar. Sometimes an expression needs a lot of manipulation to get it into a form that is recognized as something that can be solved. For your equation, though, the presence of a float is what is causing the trouble: SymPy will (by default) try to convert it to a rational and then find ALL the roots of the exponential equation (exp(n*x)-y has n solutions for x), and there will be many solutions when n = Rational(eq.atoms(Float).pop()) = 72429636961654951378944(!)

Here are solutions when replacing the Float with a symbol; but you can also get solutions by passing rational=False.

>>> eq
G - exp(7.2429636961655e+22*(-Eone + Etwo)/T)
>>> feq
G - exp(f0*(-Eone + Etwo)/T)
>>> for i in feq.free_symbols:
...  i,solve(feq,i)
...
(Eone, [T*log(exp(Etwo*f0/T)/G)/f0])
(G, [exp(f0*(-Eone + Etwo)/T)])
(f0, [-T*log(G)/(Eone - Etwo)])
(T, [f0*(-Eone + Etwo)/log(G)])
(Etwo, [T*log(G*exp(Eone*f0/T))/f0])
>>>
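
The transcript above does not show how `feq` was built. One way to do it, as a sketch: build the equation with a Python float (an assumption here; the original code parsed a string), pull out the single Float with `atoms(Float)`, and substitute a symbol for it. The name `f0` follows the output above.

```python
from sympy import Float, Symbol, exp, solve, symbols

G, T, Eone, Etwo = symbols("G T Eone Etwo")

# The same equation, written as an expression equal to zero with a float coefficient.
eq = G - exp((Etwo - Eone) / (1.3806503e-23 * T))

# Replace the one Float in the expression with a plain symbol.
coeff = eq.atoms(Float).pop()
f0 = Symbol("f0")
feq = eq.subs(coeff, f0)

# Solve the symbolic version for T, then put the number back.
sol_T = solve(feq, T)
sol_T_numeric = [s.subs(f0, coeff) for s in sol_T]
print(sol_T)
print(sol_T_numeric)
```

Substituting the number back only at the end keeps the large coefficient out of the solver altogether.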