sympy / sympy

A computer algebra system written in pure Python
https://sympy.org/

Simpler results are possible from reduce_inequalities when the variable is an integer #26069

Open asmeurer opened 8 months ago

asmeurer commented 8 months ago

Something like

>>> from sympy import symbols, reduce_inequalities, CRootOf
>>> x = symbols('x', integer=True)
>>> reduce_inequalities(4*x**3 - 4*x**2 + x - 2147483647 > 0)
CRootOf(4*x**3 - 4*x**2 + x - 2147483647, 0) < x

could be simplified to

>>> x > int(CRootOf(4*x**3 - 4*x**2 + x - 2147483647, 0))
x > 813

More generally, when the variable is an integer, a complicated numeric bound like this can be replaced by its floor or ceiling.
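The floor/ceiling replacement rests on the following equivalences, which hold for an integer x and any real number r (for instance a real CRootOf):

```latex
\begin{aligned}
x > r    &\iff x \ge \lfloor r \rfloor + 1, \\
x \ge r  &\iff x \ge \lceil r \rceil, \\
x < r    &\iff x \le \lceil r \rceil - 1, \\
x \le r  &\iff x \le \lfloor r \rfloor,
\end{aligned}
\qquad x \in \mathbb{Z},\ r \in \mathbb{R}.
```

In the example above the root lies strictly between 813 and 814, so floor(r) + 1 = 814 and CRootOf(...) < x becomes x >= 814, i.e. x > 813. Note that Python's int() truncates toward zero, so int(r) agrees with floor(r) only when r >= 0.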

I don't know if it makes more sense to do this in reduce_inequalities or to have a separate function that does this simplification (e.g., it could be done in refine).

Flamefire commented 8 months ago

Minor addition: the original issue was https://github.com/pytorch/pytorch/issues/117033, where reduce_inequalities(4*s0**3 - 4*s0**2 + s0 <= 2147483647) yields s0 <= CRootOf(4*x**3 - 4*x**2 + x - 2147483647, 0). Note the introduction of the SymPy symbol x, while the user-defined symbol is s0 (which is the integer).
Not sure if the reuse of x is a problem, but looks odd to me.

asmeurer commented 8 months ago

x in that expression is just a dummy variable. The entire CRootOf expression represents a constant number, i.e. one of the roots of that polynomial. It doesn't depend on any variables.
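A quick way to check this (a sketch, assuming a reasonably recent SymPy): a CRootOf reports no free symbols, and roots built from the same polynomial written in different symbols compare equal. The symbol s0 below just mirrors the pytorch example.

```python
from sympy import CRootOf, symbols

x = symbols('x')
s0 = symbols('s0', integer=True)

r_x = CRootOf(4*x**3 - 4*x**2 + x - 2147483647, 0)
r_s0 = CRootOf(4*s0**3 - 4*s0**2 + s0 - 2147483647, 0)

# A CRootOf is a number: it has no free symbols, so the symbol used to
# write down its polynomial plays no role in the value it represents.
print(r_x.free_symbols, r_s0.free_symbols)  # set() set()
print(r_x.is_number)                        # True

# Both roots compare equal even though they were built from different symbols.
print(r_x == r_s0)                          # True
```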

oscarbenjamin commented 8 months ago

> Not sure if the reuse of x is a problem, but looks odd to me.

It is a bound symbol so its name should not affect anything:

In [28]: rootof(x**3 + x - 1, 0)
Out[28]: 
       ⎛ 3           ⎞
CRootOf⎝x  + x - 1, 0⎠

In [29]: rootof(y**3 + y - 1, 0)
Out[29]: 
       ⎛ 3           ⎞
CRootOf⎝x  + x - 1, 0⎠

oscarbenjamin commented 8 months ago

> I don't know if it makes more sense to do this in reduce_inequalities or to have a separate function

Ideally we would have a reduce function that has modes for integer, real and complex (like Mathematica's Reduce).

In general, though, I expect that the sorts of things pytorch is doing are more specific than what you would want a general reduce function for. SymPy functions like solve, reduce_inequalities, etc. are designed according to the following premises:

  1. Users can throw anything in.
  2. The function will try as hard as possible to return the most reduced answer.

When you use SymPy interactively this is okay: you can just try things, press Ctrl-C if something turns out to be too slow, or try something else if the answer is very complicated. When you use something like solve as an API, this sort of behaviour is not what you want.

What would be more suitable for pytorch is most likely an alternative version of reduce_inequalities whose goal is explicitly stated to be that it only solves simple problems quickly or otherwise gives up. Actually much of SymPy would benefit from being able to use such functions internally as well.

How does pytorch obtain this inequality?

Is it guaranteed to be a univariate polynomial with integer coefficients?

If I knew that I wanted something very specific, like the smallest integer such that a given univariate polynomial with integer coefficients is positive, I would not use a general function like reduce_inequalities at all. SymPy provides excellent primitives for building a dedicated function that solves that problem, e.g.:

In [41]: p = 4*x**3 - 4*x**2 + x - 2147483647

In [42]: [int(r) for r in real_roots(p)]
Out[42]: [813]

In [43]: p.subs(x, 813)
Out[43]: -655522

In [44]: p.subs(x, 814)
Out[44]: 7279359

Since p has a positive leading coefficient and only one real root, which lies between 813 and 814, this is enough to prove that over the integers p <= 0 implies that x <= 813.
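Following that idea, here is a minimal sketch of such a dedicated function. It is hypothetical (the name integer_upper_bound_for_nonpositive is made up and is not a SymPy API) and assumes a univariate polynomial with integer coefficients and a positive leading coefficient, giving up with NotImplementedError otherwise. Instead of comparing against a CRootOf numerically, it starts from the rational isolating intervals returned by Poly.intervals and then uses Poly.count_roots, which counts real roots in a closed interval using exact arithmetic.

```python
from sympy import Poly, ZZ, floor, symbols


def integer_upper_bound_for_nonpositive(expr, x):
    """Hypothetical helper, not an existing SymPy API.

    For a univariate polynomial p with integer coefficients and a positive
    leading coefficient, return the smallest integer n such that p has no
    real root in [n, oo). Then p(k) > 0 for every integer k >= n, so over the
    integers p(k) <= 0 implies k <= n - 1. Return None if p has no real roots
    (p is then positive everywhere).
    """
    p = Poly(expr, x)
    # Give up quickly on anything outside this narrow, fully solvable class.
    if p.domain != ZZ or p.degree() < 1 or p.LC() <= 0:
        raise NotImplementedError("unsupported polynomial: %s" % p)

    isolating = p.intervals()  # [((a, b), multiplicity), ...] with rational a, b
    if not isolating:
        return None

    # Every real root is <= the largest upper end point b, and floor(b) + 1 > b,
    # so there is no root at or above this starting value of n.
    n = floor(max(b for (_, b), _ in isolating)) + 1

    # count_roots(n - 1) counts real roots in [n - 1, oo) exactly (no floating
    # point). Step down while that range is still root-free.
    while p.count_roots(n - 1) == 0:
        n -= 1
    return n


x = symbols('x')
print(integer_upper_bound_for_nonpositive(4*x**3 - 4*x**2 + x - 2147483647, x))  # 814
```

The result 814 says that p(k) > 0 for every integer k >= 814, which is the same conclusion as above: p <= 0 forces x <= 813.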

oscarbenjamin commented 8 months ago

To be clear, I am not suggesting that SymPy should not add something that could solve this problem. My point is just that use cases like pytorch's are most likely better served by functions that are explicitly limited in scope than by general functions like reduce_inequalities.

Many SymPy functions are by design expected to attempt to solve problems that are not solvable in full generality. In practice, though, it is usually better to use a function whose scope is restricted to a well-defined, fully solvable class of problems. It would be better if more SymPy functions had an explicitly documented limited scope, and if contributors did not try to expand that scope to include problem classes that are not solvable in general (in reasonable time for reasonable inputs).

We could have e.g. reduce_integer_inequalities and then require very limiting constraints on the allowable set of inputs so that the problem is always well defined and can be expected to complete in a reasonable time for any simple inputs.
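One way to make such a limited scope explicit is to validate the input up front and give up immediately on anything else. The sketch below is hypothetical: check_integer_inequality_input is not a SymPy function, and MAX_DEGREE and MAX_COEFF_BITS are illustrative limits chosen only for this example.

```python
from sympy import Poly, ZZ
from sympy.core.relational import Relational
from sympy.polys.polyerrors import PolynomialError

# Illustrative limits only (assumptions, not values taken from SymPy).
MAX_DEGREE = 8
MAX_COEFF_BITS = 64


def check_integer_inequality_input(rel, x):
    """Hypothetical input gate for a limited-scope reduce_integer_inequalities.

    Accept only a single inequality whose two sides differ by a univariate
    polynomial in x with integer coefficients of bounded degree and size.
    Return (Poly, operator), or raise NotImplementedError so that unsupported
    input fails fast instead of triggering an expensive general computation.
    """
    if not isinstance(rel, Relational) or rel.rel_op not in ('<', '<=', '>', '>='):
        raise NotImplementedError("expected a single inequality, got %s" % (rel,))
    try:
        p = Poly(rel.lhs - rel.rhs, x)
    except PolynomialError:
        raise NotImplementedError("not a polynomial in %s: %s" % (x, rel))
    if p.domain != ZZ:
        raise NotImplementedError("coefficients must be integers: %s" % p)
    if p.degree() > MAX_DEGREE:
        raise NotImplementedError("degree %s exceeds %s" % (p.degree(), MAX_DEGREE))
    if max(int(abs(c)).bit_length() for c in p.all_coeffs()) > MAX_COEFF_BITS:
        raise NotImplementedError("coefficients are too large")
    return p, rel.rel_op
```

With these checks, the pytorch inequality 4*s0**3 - 4*s0**2 + s0 <= 2147483647 is accepted, while inputs such as x*y < 1 (a second symbol), x/2 < 1 (rational coefficients) or sin(x) < 1 are rejected before any solving is attempted.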

Another possibility would be a Poly method. There is already an intervals method:

In [50]: Poly(4*x**3 - 4*x**2 + x - 2147483647).intervals()
Out[50]: [((813, 814), 1)]

This is intended to find intervals that enclose the roots, but we could also have a method to define intervals with either integer or RootOf boundaries on which the polynomial is positive or negative (it wouldn't surprise me if such a function already existed somewhere in sympy.polys).
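The existing primitives already go part of the way. As a small sketch: count_roots gives the exact number of distinct real roots, and intervals accepts an eps argument that refines each isolating interval until it is narrower than eps, keeping the end points as exact rationals.

```python
from sympy import Poly, Rational, symbols

x = symbols('x')
p = Poly(4*x**3 - 4*x**2 + x - 2147483647, x)

print(p.count_roots())  # 1: exactly one distinct real root

# Refine the isolating interval of that root below a width of eps.
eps = Rational(1, 10**6)
[((a, b), mult)] = p.intervals(eps=eps)
print(a, b, mult)

# Between consecutive real roots the sign of p is constant, so one exact
# evaluation per gap determines it; here p is negative at a and positive at b.
print(p.eval(a) < 0 < p.eval(b))  # True
```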

The intended design needs to be clear up front, though, and in practice there are many questions that are not usually considered when designing a SymPy API:

  1. Must all solutions be found?
  2. Must any solution be the best possible solution?
  3. What if the problem is underdetermined (do you just want any numeric solution or a full parametrisation of the complete solution set)?
  4. Are approximate solutions acceptable?
  5. What is a reasonable range of input sizes for the problem at hand (e.g. size of integers, degree of polynomials, number of variables, etc)?

These things are not usually considered when designing high-level SymPy API because the premise is always that a function should be expected to solve the most general class of problems in the most general possible way regardless of arbitrary slowness and regardless of whether or not the posed problem space is even solvable (or useful).

oscarbenjamin commented 8 months ago

Here are candidate functions/methods that could be added in polys:

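from sympy import S, Poly, ceiling, floor, oo, sympify
from sympy.polys.domains import ZZ, QQ, RR
from sympy.polys.polyerrors import MultivariatePolynomialError

def choose_rational(a, b):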
    """Choose a rational number in the open interval (a, b).

    Parameters
    ==========

    a : Expr, a numeric expression.
        Lower bound of the interval.
    b : Expr, a numeric expression.
        Upper bound of the interval.

    Returns
    =======

    c : Rational, a rational number.
        If possible then c will be an integer.

    Examples
    ========

    >>> from sympy import Rational, sqrt
    >>> choose_rational(0, 1)
    1/2
    >>> choose_rational(0, Rational(1, 3))
    1/4
    >>> choose_rational(sqrt(2), sqrt(3))
    3/2

    Notes
    =====

    The interval (a, b) must be non-empty. The end points a and b can be
    infinite. If they are not infinite then they must either be explicit
    rational numbers or expressions representing (real) irrational numbers.
    Roots returned by real_roots will satisfy these conditions.
    """
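    a, b = sympify(a), sympify(b)

    if not (a < b):
        raise ValueError("the interval (%s, %s) is empty" % (a, b))

    if a < S.Zero < b:
        return S.Zero
    elif a is S.NegativeInfinity:
        return ceiling(b) - 1
    elif b is S.Infinity:
        return floor(a) + 1

    inf = floor(a)
    sup = ceiling(b)

    # Return an integer if possible: every integer in [inf + 1, sup - 1] lies
    # strictly inside (a, b), and that range is non-empty when sup - inf >= 2.
    if sup - inf >= 2:
        return (inf + sup) // 2

    # Bisect to find a rational number with power of 2 denominator. A midpoint
    # landing exactly on a must raise the lower bound (hence <=); otherwise the
    # bracket would shrink below a and the loop would never terminate.
    midpoint = (inf + sup) / 2

    while not (a < midpoint < b):
        if midpoint <= a:
            inf = midpoint
        else:
            sup = midpoint
        midpoint = (inf + sup) / 2

    return midpoint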

def poly_intervals(p, inf=-oo, sup=oo):
    """Open real intervals where polynomial p is positive or negative.

    Parameters
    ==========

    p : Poly/Expr
        A univariate polynomial with rational or floating point coefficients.

    inf : rational/float/-oo
        Lower bound of the range of interest.

    sup : rational/float/+oo
        Upper bound of the range of interest.

    Examples
    ========

    >>> from sympy.abc import x
    >>> p = x**2 - 3
    >>> poly_intervals(p)
    [(-oo, -sqrt(3), '+'), (-sqrt(3), sqrt(3), '-'), (sqrt(3), oo, '+')]
    >>> poly_intervals(p, 0, 10)
    [(0, sqrt(3), '-'), (sqrt(3), 10, '+')]
    """
    p = Poly(p)
    inf, sup = sympify(inf), sympify(sup)

    if p.domain not in (ZZ, QQ, RR):
        raise NotImplementedError("can't compute sign of %s" % p)
    if not p.is_univariate:
        raise MultivariatePolynomialError(p)

    if p.domain == RR:
        p = p.to_exact()

    if p.is_zero:
        return []

    roots = sorted(set(p.real_roots()))

    while roots and roots[0] <= inf:
        roots.pop(0)
    while roots and roots[-1] >= sup:
        roots.pop()

    roots = [inf] + roots + [sup]

    intervals = []

    for a, b in zip(roots[:-1], roots[1:]):
        q = choose_rational(a, b)
        sign = '+' if p(q) > 0 else '-'
        intervals.append((a, b, sign))

    return intervals

def poly_intervals_integer(p, inf=-oo, sup=oo):
    """Closed integer intervals where polynomial p is positive or negative.

    Examples
    ========

    >>> from sympy import oo
    >>> from sympy.abc import x
    >>> p = x**2 - 3
    >>> poly_intervals_integer(p)
    [(-oo, -2, '+'), (-1, 1, '-'), (2, oo, '+')]
    >>> poly_intervals_integer(4*x**3 - 4*x**2 + x - 2147483647, 0, oo)
    [(0, 813, '-'), (814, oo, '+')]
    """
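    p = Poly(p)
    if p.domain == RR:
        p = p.to_exact()

    intervals = []

    for a, b, sign in poly_intervals(p, inf, sup):
        # ceiling(a) and floor(b) coincide with a and b only when those end
        # points are integers: integer roots of p (where p is zero, neither
        # positive nor negative) or the bounds inf/sup. Drop an end point
        # only if p actually vanishes there.
        a = ceiling(a)
        b = floor(b)
        if a.is_finite and p(a) == 0:
            a += 1
        if b.is_finite and p(b) == 0:
            b -= 1
        if a <= b:
            intervals.append((a, b, sign))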

    return intervals
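Candidates like poly_intervals_integer are easy to get subtly wrong at integer end points, so a brute-force reference can help when testing them. The sketch below uses a hypothetical helper (integer_sign_runs is not part of the proposal above) that evaluates p at every integer in a finite window and groups consecutive integers by sign. On x**2 - 3 over [-5, 5] it should agree with the poly_intervals_integer docstring example clipped to that window; on x**2 - 4 the integer roots -2 and 2 show up as separate '0' runs that neither a '+' nor a '-' interval should contain.

```python
from sympy import Poly, symbols


def integer_sign_runs(expr, x, lo, hi):
    """Hypothetical brute-force reference, not part of the proposal above.

    Evaluate p at every integer k in [lo, hi] and group consecutive integers
    on which p has the same sign ('+', '-', or '0' at an integer root).
    """
    p = Poly(expr, x)
    runs = []
    for k in range(lo, hi + 1):
        v = p.eval(k)
        sign = '+' if v > 0 else '-' if v < 0 else '0'
        if runs and runs[-1][2] == sign:
            runs[-1] = (runs[-1][0], k, sign)  # extend the current run
        else:
            runs.append((k, k, sign))          # start a new run
    return runs


x = symbols('x')
print(integer_sign_runs(x**2 - 3, x, -5, 5))
# [(-5, -2, '+'), (-1, 1, '-'), (2, 5, '+')]
print(integer_sign_runs(x**2 - 4, x, -3, 3))
# [(-3, -3, '+'), (-2, -2, '0'), (-1, 1, '-'), (2, 2, '0'), (3, 3, '+')]
```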