asmeurer opened 8 months ago
Minor enhancement. The original issue was https://github.com/pytorch/pytorch/issues/117033, where `reduce_inequalities(4*s0**3 - 4*s0**2 + s0 <= 2147483647)` yields `s0 <= CRootOf(4*x**3 - 4*x**2 + x - 2147483647, 0)` (note the introduction of the SymPy symbol `x`, while the user-defined symbol is `s0`, which is the integer). Not sure if the reuse of `x` is a problem, but it looks odd to me.
x in that expression is just a dummy variable. The entire CRootOf expression represents a constant number, i.e. one of the roots of that polynomial. It doesn't depend on any variables.
> Not sure if the reuse of `x` is a problem, but looks odd to me.
It is a bound symbol, so its name should not affect anything:

```
In [28]: rootof(x**3 + x - 1, 0)
Out[28]: CRootOf(x**3 + x - 1, 0)

In [29]: rootof(y**3 + y - 1, 0)
Out[29]: CRootOf(x**3 + x - 1, 0)
```
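A quick way to check this (a sketch; `rx` and `ry` are just illustrative names): the root object has no free symbols, reports itself as a number, and evaluates to the same value whichever generator name was used to build it.

```python
from sympy import rootof, symbols

x, y = symbols('x y')

# The polynomial's generator is a bound (dummy) variable inside CRootOf.
rx = rootof(x**3 + x - 1, 0)
ry = rootof(y**3 + y - 1, 0)

print(rx.free_symbols)          # empty set: the root does not depend on x
print(rx.is_number)             # it is a constant
print(rx.evalf(20), ry.evalf(20))  # same numerical value for both names
```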
> I don't know if it makes more sense to do this in `reduce_inequalities` or to have a separate function
Ideally we would have a `reduce` function that has modes for integer, real and complex domains (like Mathematica's `Reduce`).
In general, though, I expect that the sorts of things that pytorch is doing are likely more specific than what you would want a general reduce function for. SymPy functions like `solve`, `reduce_inequalities`, etc. are designed according to the following premises:
When you use SymPy interactively this is okay, because you can just try things and press Ctrl-C if something turns out to be too slow, or try something else if the answer is very complicated. When you use something like `solve` as an API, this sort of behaviour is not what you want.
What would be more suitable for pytorch is most likely an alternative version of `reduce_inequalities` whose explicitly stated goal is to solve only simple problems quickly and otherwise give up. Much of SymPy would actually benefit from being able to use such functions internally as well.
How does pytorch obtain this inequality?
Is it guaranteed to be a univariate polynomial with integer coefficients?
If I knew that I wanted something very specific, like the smallest integer such that a given univariate polynomial with integer coefficients is positive, I would not use a general function like `reduce_inequalities` at all. SymPy provides excellent primitives for writing a dedicated function that solves that problem, e.g.:
```
In [41]: p = 4*x**3 - 4*x**2 + x - 2147483647

In [42]: [int(r) for r in real_roots(p)]
Out[42]: [813]

In [43]: p.subs(x, 813)
Out[43]: -655522

In [44]: p.subs(x, 814)
Out[44]: 7279359
```
Since `p` has a positive leading coefficient and only one real root, which lies between 813 and 814, this is enough to prove that over the integers `p <= 0` implies `x <= 813`.
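A minimal sketch of such a dedicated function, assuming the input really is a univariate polynomial with integer coefficients and a positive leading coefficient. `positive_threshold` is a hypothetical name, not an existing SymPy function; it only uses `Poly.real_roots` and exact integer truncation via `int()`, and returns the smallest integer `n` such that the polynomial is positive for every integer `k >= n`.

```python
from sympy import Integer, Poly, ZZ, symbols


def positive_threshold(expr, x):
    """Hypothetical helper: smallest integer n with expr > 0 for all integers k >= n.

    Accepts only univariate polynomials in x with integer coefficients and a
    positive leading coefficient. Returns None if expr > 0 for every real x.
    """
    p = Poly(expr, x)
    if p.domain != ZZ:
        raise ValueError("expected integer coefficients in %s" % x)
    if p.LC() <= 0:
        raise ValueError("expected a positive leading coefficient")
    roots = p.real_roots()  # sorted exact roots: Integer, radical or CRootOf
    if not roots:
        return None  # no real root, so p > 0 everywhere
    r = roots[-1]
    n = Integer(int(r))  # int() truncates toward zero
    if n > r:  # turn truncation into floor for negative non-integers
        n -= 1
    # Beyond its largest real root, p has the sign of its leading coefficient.
    return n + 1


x = symbols('x')
expr = 4*x**3 - 4*x**2 + x - 2147483647
n = positive_threshold(expr, x)
print(n, expr.subs(x, n - 1), expr.subs(x, n))  # 814 -655522 7279359
```

The answer is exact because it never relies on a floating-point approximation of the root, only on comparisons that SymPy can decide for an isolated real root.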
To be clear, I am not suggesting that SymPy should not add something that could solve this problem. My point is just that use cases like pytorch's are most likely better served by functions that are explicitly limited in scope than by general functions like `reduce_inequalities`.
Many SymPy functions are by design expected to attempt to solve problems that are not solvable in full generality. Usually in practice, though, it is better to use a function whose scope is restricted to a well-defined, fully solvable class of problems. It would be better if more SymPy functions had an explicitly documented limited scope and if contributors did not try to expand the scope to include problem classes that are not solvable in general (in reasonable time for reasonable inputs).
We could have e.g. `reduce_integer_inequalities` and then impose very limiting constraints on the allowable inputs, so that the problem is always well defined and can be expected to complete in a reasonable time for any simple input.
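As a rough illustration of what such constraints could look like in practice, here is a hypothetical scope check. The name `integer_poly_or_give_up` and its `max_degree=10` cutoff are assumptions for this sketch, not SymPy API. It accepts only univariate polynomials with integer coefficients and otherwise gives up immediately with `NotImplementedError` instead of trying harder.

```python
from sympy import Poly, ZZ, sin, symbols
from sympy.polys.polyerrors import PolynomialError


def integer_poly_or_give_up(expr, x, max_degree=10):
    """Hypothetical scope check: return Poly(expr, x) if expr is a polynomial
    in x with integer coefficients and degree <= max_degree; otherwise raise
    NotImplementedError straight away."""
    try:
        p = Poly(expr, x)
    except PolynomialError:
        # e.g. sin(x): not a polynomial in x at all
        raise NotImplementedError("not a polynomial in %s: %s" % (x, expr)) from None
    if p.domain != ZZ:
        # e.g. QQ for x/2, or ZZ[y] when another symbol y is present
        raise NotImplementedError("coefficients must be integers, got %s" % p.domain)
    if p.degree() > max_degree:
        raise NotImplementedError("degree %d exceeds %d" % (p.degree(), max_degree))
    return p


x, y = symbols('x y')
p = integer_poly_or_give_up(4*x**3 - 4*x**2 + x - 2147483647, x)
for bad in (sin(x), x/2, x*y, x**11):
    try:
        integer_poly_or_give_up(bad, x)
    except NotImplementedError as exc:
        print("gave up:", exc)
```

Everything that passes this check is a problem that `real_roots` can always answer exactly, which is the point of restricting the scope up front.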
Another possibility would be a `Poly` method. There is already `intervals`:

```
In [50]: Poly(4*x**3 - 4*x**2 + x - 2147483647).intervals()
Out[50]: [((813, 814), 1)]
```
This is intended to find intervals that enclose the roots, but we could also have a method that defines intervals, with either integer or `RootOf` boundaries, on which the polynomial is positive or negative (it wouldn't surprise me if such a function already existed somewhere in `sympy.polys`).
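For this particular polynomial the isolating interval already answers the integer question, provided there is exactly one real root. A sketch that checks this with `Poly.count_roots` and `Poly.eval`; the last step relies on the endpoints being consecutive integers, as they happen to be here, and in general the interval would first need refining.

```python
from sympy import Poly, symbols

x = symbols('x')
p = Poly(4*x**3 - 4*x**2 + x - 2147483647, x)

# Isolating intervals ((a, b), multiplicity) for the real roots of p.
isolating = p.intervals()
# count_roots() with no bounds counts all real roots of p.
assert p.count_roots() == len(isolating) == 1
(a, b), mult = isolating[0]
# Consecutive integer endpoints with a sign change bracket the root strictly.
# Since LC(p) > 0, p < 0 to the left of the root, so a is the largest
# integer with p <= 0.
assert b - a == 1 and p.eval(a) < 0 < p.eval(b)
print(a, b, mult)
```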
The intended design needs to be clear up front, though, and there are many questions in practice that are not usually considered when designing a SymPy API:
These things are not usually considered when designing high-level SymPy APIs, because the premise is always that a function should be expected to solve the most general class of problems in the most general possible way, regardless of arbitrary slowness and regardless of whether or not the posed problem space is even solvable (or useful).
Here are candidate functions/methods that could be added in polys:
```python
from sympy import Integer, Poly, QQ, RR, S, ZZ, oo, sympify
from sympy.polys.polyerrors import MultivariatePolynomialError


def _floor(r):
    """Exact floor of r (rational, radical or CRootOf); +-oo is returned as is.

    int() truncates toward zero and always yields an explicit Integer, so the
    result is corrected by one for negative non-integers.
    """
    r = sympify(r)
    if r.is_infinite:
        return r
    n = Integer(int(r))
    return n - 1 if n > r else n


def _ceiling(r):
    """Exact ceiling of r, defined via _floor."""
    return -_floor(-sympify(r))


def choose_rational(a, b):
    """Choose a rational number in the open interval (a, b).

    Parameters
    ==========

    a : Expr, a numeric expression.
        Lower bound of the interval.
    b : Expr, a numeric expression.
        Upper bound of the interval.

    Returns
    =======

    c : Rational, a rational number.
        If possible then c will be an integer.

    Examples
    ========

    >>> from sympy import Rational, sqrt
    >>> choose_rational(0, 1)
    1/2
    >>> choose_rational(0, Rational(1, 3))
    1/4
    >>> choose_rational(sqrt(2), sqrt(3))
    3/2

    Notes
    =====

    The interval (a, b) must be non-empty. The end points a and b can be
    infinite. If they are not infinite then they must either be explicit
    rational numbers or expressions representing (real) irrational numbers.
    Roots returned by real_roots will satisfy these conditions.
    """
    a, b = sympify(a), sympify(b)
    # Check emptiness first: otherwise the bisection below never terminates
    if not (a < b):
        raise ValueError("the interval (%s, %s) is empty" % (a, b))
    if a < S.Zero < b:
        return S.Zero
    elif a is S.NegativeInfinity:
        return _ceiling(b) - 1
    elif b is S.Infinity:
        return _floor(a) + 1
    inf = _floor(a)
    sup = _ceiling(b)
    # Return an integer if possible: floor(a) < m < ceiling(b) implies a < m < b
    if sup - inf >= 2:
        return (inf + sup) // 2
    # Bisect to find a rational number with power of 2 denominator
    midpoint = (inf + sup) / 2
    while not (a < midpoint < b):
        # midpoint == a must move the lower end, not the upper one
        if midpoint <= a:
            inf = midpoint
        else:
            sup = midpoint
        midpoint = (inf + sup) / 2
    return midpoint


def poly_intervals(p, inf=-oo, sup=oo):
    """Open real intervals where polynomial p is positive or negative.

    Parameters
    ==========

    p : Poly/Expr
        A univariate polynomial with rational or floating point coefficients.
    inf : rational/float/-oo
        Lower bound of the range of interest.
    sup : rational/float/+oo
        Upper bound of the range of interest.

    Examples
    ========

    >>> from sympy.abc import x
    >>> p = x**2 - 3
    >>> poly_intervals(p)
    [(-oo, -sqrt(3), '+'), (-sqrt(3), sqrt(3), '-'), (sqrt(3), oo, '+')]
    >>> poly_intervals(p, 0, 10)
    [(0, sqrt(3), '-'), (sqrt(3), 10, '+')]
    """
    p = Poly(p)
    if not p.is_univariate:
        raise MultivariatePolynomialError(p)
    if p.domain not in (ZZ, QQ, RR):
        raise NotImplementedError("can't compute sign of %s" % p)
    if p.domain == RR:
        p = p.to_exact()
    if p.is_zero:
        return []
    inf, sup = sympify(inf), sympify(sup)
    # real_roots() is sorted and repeats multiple roots; drop the repeats
    roots = list(dict.fromkeys(p.real_roots()))
    # Keep only the roots strictly inside (inf, sup)
    while roots and roots[0] <= inf:
        roots.pop(0)
    while roots and roots[-1] >= sup:
        roots.pop()
    bounds = [inf] + roots + [sup]
    intervals = []
    for a, b in zip(bounds[:-1], bounds[1:]):
        # p has constant sign on (a, b), so test it at one rational point
        q = choose_rational(a, b)
        sign = '+' if p(q) > 0 else '-'
        intervals.append((a, b, sign))
    return intervals


def poly_intervals_integer(p, inf=-oo, sup=oo):
    """Closed integer intervals where polynomial p is positive or negative.

    Examples
    ========

    >>> from sympy.abc import x
    >>> p = x**2 - 3
    >>> poly_intervals_integer(p)
    [(-oo, -2, '+'), (-1, 1, '-'), (2, oo, '+')]
    >>> poly_intervals_integer(4*x**3 - 4*x**2 + x - 2147483647, 0, oo)
    [(0, 813, '-'), (814, oo, '+')]
    """
    p = Poly(p)
    if p.domain == RR:
        p = p.to_exact()
    intervals = []
    for a, b, sign in poly_intervals(p, inf, sup):
        lo, hi = _ceiling(a), _floor(b)
        # An integer endpoint that is a root of p has sign 0, not `sign`
        if lo.is_finite and p(lo) == 0:
            lo += 1
        if hi.is_finite and p(hi) == 0:
            hi -= 1
        if lo <= hi:
            intervals.append((lo, hi, sign))
    return intervals
```
Something like `s0 <= CRootOf(4*x**3 - 4*x**2 + x - 2147483647, 0)` could be simplified to `s0 <= 813`, and in general, when the variable is an integer, simplifications like this that replace complicated numeric expressions with their floor or ceiling can be done.
I don't know if it makes more sense to do this in `reduce_inequalities` or to have a separate function that does this simplification (e.g., it could be done in `refine`).
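A minimal sketch of the floor/ceiling simplification as a standalone function. The name `tighten_integer_bound` and its helpers are hypothetical, not part of SymPy. It assumes the relation has an integer symbol on the left and a finite real number on the right, and it uses exact `int()` truncation rather than a floating-point approximation, so the new bound is an exact Integer.

```python
from sympy import Ge, Gt, Integer, Le, Lt, rootof, symbols


def _floor(r):
    """Exact floor of a finite real number r (Integer, Rational, radical or CRootOf)."""
    n = Integer(int(r))  # int() truncates toward zero
    return n - 1 if n > r else n


def _ceiling(r):
    return -_floor(-r)


def tighten_integer_bound(rel):
    """Hypothetical helper: rewrite s <= r, s < r, s >= r or s > r with an
    explicit Integer bound, where s is an integer symbol and r a finite real
    number. Anything else is returned unchanged (the sketch gives up)."""
    s, r = rel.lhs, rel.rhs
    if not (s.is_Symbol and s.is_integer and r.is_number
            and r.is_real and r.is_finite):
        return rel
    if isinstance(rel, Le):  # s <= r  <=>  s <= floor(r)
        return Le(s, _floor(r))
    if isinstance(rel, Lt):  # s < r   <=>  s <= ceiling(r) - 1
        return Le(s, _ceiling(r) - 1)
    if isinstance(rel, Ge):  # s >= r  <=>  s >= ceiling(r)
        return Ge(s, _ceiling(r))
    if isinstance(rel, Gt):  # s > r   <=>  s >= floor(r) + 1
        return Ge(s, _floor(r) + 1)
    return rel


x = symbols('x')
s0 = symbols('s0', integer=True)
r = rootof(4*x**3 - 4*x**2 + x - 2147483647, 0)
print(tighten_integer_bound(Le(s0, r)))  # s0 <= 813
```

Whether this belongs in `reduce_inequalities`, `refine`, or a separate function is the open question above; the sketch only shows that the rewrite itself is cheap once the right-hand side is an isolated real number.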