sympy / sympy

A computer algebra system written in pure Python
https://sympy.org/

Integration fails with "Invalid NaN comparison" #21066

Open petschge opened 3 years ago

petschge commented 3 years ago

Running the following code:

#!/usr/bin/env python3
import sympy
from sympy import Eq, exp, integrate, oo, pi, simplify, sin, cos, solve, sqrt

# define symbols
a,mu,sigma,sigma2,phi,vperp = sympy.symbols("a mu sigma sigma2 phi vperp", positive=True)
vpara = sympy.symbols("vpara", real=True)

# drifting bi-Maxwellian with unknown normalization
h = a * exp(- (vperp*sin(phi))**2 / (2 * sigma**2)) * exp(- (vperp*cos(phi))**2 / (2 * sigma**2)) * exp(- (vpara-mu)**2 / (2 * sigma2**2))
# the vperp comes from the volume element in phi-vpara-vperp space
I1 = integrate(integrate(integrate( h*vperp , (phi,0,2*pi) ), (vperp,0,oo) ), (vpara,-oo,oo))
# compute normalization
asol = solve(Eq(I1, 1), a)[0]
# correctly normalized bi-Maxwellian with drift
f = simplify(h.subs(a, asol))
# compute moment of distribution function related to the thermal force
I2 = integrate(integrate(integrate(vpara/sqrt(vpara**2+vperp**2) * 1/(vpara**2+vperp**2) *  f*vperp , (phi,0,2*pi) ), (vperp,0,oo) ), (vpara,-oo,oo))
# we don't even get here
print(I2)

results in the following error after a few minutes:

Traceback (most recent call last):
  File "./minimal3.py", line 18, in <module>
    I2 = integrate(integrate(integrate(vpara/sqrt(vpara**2+vperp**2) * 1/(vpara**2+vperp**2) *  f*vperp , (phi,0,2*pi) ), (vperp,0,oo) ), (vpara,-oo,oo))
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/integrals.py", line 1544, in integrate
    return integral.doit(**doit_flags)
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/integrals.py", line 594, in doit
    function, xab[0], **eval_kwargs)
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/integrals.py", line 1079, in _eval_integral
    result = manualintegrate(g, x)
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/manualintegrate.py", line 1654, in manualintegrate
    result = _manualintegrate(integral_steps(f, var))
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/manualintegrate.py", line 1329, in integral_steps
    fallback_rule)(integral)
  File "~/.local/lib/python3.7/site-packages/sympy/strategies/core.py", line 85, in do_one_rl
    result = rl(expr)
  File "~/.local/lib/python3.7/site-packages/sympy/strategies/core.py", line 85, in do_one_rl
    result = rl(expr)
  File "~/.local/lib/python3.7/site-packages/sympy/strategies/core.py", line 65, in null_safe_rl
    result = rule(expr)
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/manualintegrate.py", line 321, in _alternatives
    result = rule(integral)
  File "~/.local/lib/python3.7/site-packages/sympy/strategies/core.py", line 33, in conditioned_rl
    return rule(expr)
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/manualintegrate.py", line 611, in parts_rule
    result = _parts_rule(integrand, symbol)
  File "~/.local/lib/python3.7/site-packages/sympy/integrals/manualintegrate.py", line 595, in _parts_rule
    if r and r[0].subs(dummy, 1).equals(dv):
  File "~/.local/lib/python3.7/site-packages/sympy/core/expr.py", line 743, in equals
    constant = diff.is_constant(simplify=False, failing_number=True)
  File "~/.local/lib/python3.7/site-packages/sympy/core/expr.py", line 651, in is_constant
    simultaneous=True)
  File "~/.local/lib/python3.7/site-packages/sympy/core/basic.py", line 944, in subs
    return rv.xreplace(reps)
  File "~/.local/lib/python3.7/site-packages/sympy/core/basic.py", line 1138, in xreplace
    value, _ = self._xreplace(rule)
  File "~/.local/lib/python3.7/site-packages/sympy/core/basic.py", line 1153, in _xreplace
    a_xr = _xreplace(rule)
  File "~/.local/lib/python3.7/site-packages/sympy/core/basic.py", line 1153, in _xreplace
    a_xr = _xreplace(rule)
  File "~/.local/lib/python3.7/site-packages/sympy/core/basic.py", line 1153, in _xreplace
    a_xr = _xreplace(rule)
  [Previous line repeated 3 more times]
  File "~/.local/lib/python3.7/site-packages/sympy/core/basic.py", line 1160, in _xreplace
    return self.func(*args), True
  File "~/.local/lib/python3.7/site-packages/sympy/core/relational.py", line 702, in __new__
    raise TypeError("Invalid NaN comparison")
TypeError: Invalid NaN comparison
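For reference, the normalization integral I1 in the script can be evaluated by hand. Because sin²(phi) + cos²(phi) = 1, the two perpendicular exponentials combine into exp(-vperp²/(2 sigma²)) and the triple integral factorizes, so asol should be equivalent to the expression for a below:

```latex
I_1 = a \int_0^{2\pi} d\phi \int_0^{\infty} e^{-v_\perp^2/(2\sigma^2)}\, v_\perp \, dv_\perp
      \int_{-\infty}^{\infty} e^{-(v_\parallel-\mu)^2/(2\sigma_2^2)}\, dv_\parallel
    = a \cdot 2\pi \cdot \sigma^2 \cdot \sqrt{2\pi}\,\sigma_2
    = a\,(2\pi)^{3/2}\,\sigma^2\,\sigma_2,
\qquad
I_1 = 1 \;\Rightarrow\; a = \frac{1}{(2\pi)^{3/2}\,\sigma^2\,\sigma_2}.
```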
oscarbenjamin commented 3 years ago

A simpler way to reproduce the bug is

In [57]: e = Piecewise((2, Abs(arg(x)) < pi), (1, True))

In [58]: e
Out[58]: 
⎧2  for │arg(x)│ < π
⎨                   
⎩1     otherwise    

In [59]: e.is_constant()
---------------------------------------------------------------------------
TypeError 
ghost commented 3 years ago

Actually, the problem lies here:

>>> arg(0)
nan

is_constant tries to check whether the expression takes different values at two points. As can be seen, one of the selected points is 0, so this is bound to give an error. https://github.com/sympy/sympy/blob/2346054bb4888ef7eec2f6dad6c3dd52bf1fe927/sympy/core/expr.py#L657-L659
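A plain-Python sketch of this failure mode, assuming a simplified two-point check. sym_arg, strict_less, piecewise_example and looks_constant are hypothetical stand-ins that mimic the SymPy behaviour described in this thread (arg(0) is NaN, and comparing NaN raises TypeError); they are not SymPy's actual implementation.

```python
import cmath
import math


def sym_arg(z):
    """Mimic SymPy's convention that arg(0) is NaN (cmath.phase(0) is 0.0)."""
    if z == 0:
        return math.nan
    return cmath.phase(z)


def strict_less(a, b):
    """Mimic a SymPy relational that refuses to compare NaN."""
    if math.isnan(a) or math.isnan(b):
        raise TypeError("Invalid NaN comparison")
    return a < b


def piecewise_example(x):
    # Plain-Python analogue of Piecewise((2, Abs(arg(x)) < pi), (1, True))
    return 2 if strict_less(abs(sym_arg(x)), math.pi) else 1


def looks_constant(f, points):
    """Hypothetical check: do all sample points give the same value?

    Only ZeroDivisionError is treated as "cannot evaluate here", so any
    other exception raised while evaluating f escapes to the caller.
    """
    values = []
    for p in points:
        try:
            values.append(f(p))
        except ZeroDivisionError:
            return None  # undecided: f could not be evaluated at p
    return all(v == values[0] for v in values)
```

looks_constant(piecewise_example, [1, -1]) returns False, but as soon as 0 is one of the sample points the TypeError escapes, mirroring the traceback in the report.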

A possible solution is to set arg(0) to 0, though I don't know whether this change would be backwards compatible.
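As a point of comparison (not a statement about what SymPy should do), Python's floating-point routines already use the convention that the phase of 0 is 0, following C99 Annex F for atan2. The sketch also shows that the result depends on the sign of the zero, which illustrates that any value for arg(0) is a convention rather than something forced by the mathematics.

```python
import cmath
import math

# Floating-point convention: atan2(+0.0, +0.0) is +0.0, so phase(0) == 0.0
phase_of_zero = cmath.phase(0j)
atan2_of_zeros = math.atan2(0.0, 0.0)

# With a negative-zero real part the same "zero" has phase pi
phase_of_negative_zero = cmath.phase(complex(-0.0, 0.0))

print(phase_of_zero, atan2_of_zeros, phase_of_negative_zero)
# 0.0 0.0 3.141592653589793
```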

oscarbenjamin commented 3 years ago

We could change arg(0) to be zero, but the is_constant method would still be unreliable: it substitutes random values into an arbitrary expression, so it still needs to be made more robust.

EricWay1024 commented 3 years ago

In https://github.com/sympy/sympy/blob/b564d9ba2705ee2978766e1ee2102750834df68b/sympy/core/expr.py#L657-L664 the original code already considers the case where substitution fails and tries to catch such exceptions. The only problem is that it assumes only ZeroDivisionError can occur, whereas here a TypeError is raised and therefore not caught. Adding TypeError to the except clause may be a quick fix.

oscarbenjamin commented 3 years ago

The problem is that TypeError can be raised for all sorts of reasons, including bugs in the code:

In [20]: def f(x): return x**2

In [21]: f()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-21-c43e34e6d405> in <module>
----> 1 f()

TypeError: f() missing 1 required positional argument: 'x'

We need more specific exception classes. It would be much better to write except NaNError than except TypeError.

In any case, when adding any except clause like this, there needs to be a clear comment explaining why the exception is being caught and giving examples of input that lead to it being caught.
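A minimal pure-Python sketch of the difference between catching a broad TypeError and catching a dedicated exception. NaNComparisonError, strict_less, evaluate_broad, evaluate_narrow, buggy and nan_at_zero are hypothetical names for illustration, not SymPy API. The narrow handler also carries the kind of explanatory comment described above.

```python
import math


class NaNComparisonError(Exception):
    """Hypothetical dedicated exception for comparisons involving NaN."""


def strict_less(a, b):
    if math.isnan(a) or math.isnan(b):
        raise NaNComparisonError("Invalid NaN comparison")
    return a < b


def evaluate_broad(f, x):
    try:
        return f(x)
    except TypeError:
        return None  # swallows *any* TypeError, including bugs in f


def evaluate_narrow(f, x):
    try:
        return f(x)
    except NaNComparisonError:
        # Caught because sample points can hit NaN: e.g. arg(0) is NaN,
        # which makes Abs(arg(x)) < pi undecidable at x = 0.
        return None


def buggy(x):
    return math.sqrt()  # bug: missing argument raises TypeError


def nan_at_zero(x):
    # Stand-in for Abs(arg(x)) < pi with arg(0) = NaN
    return 2 if strict_less(math.nan if x == 0 else 0.0, math.pi) else 1
```

evaluate_broad(buggy, 1) silently returns None and hides the bug, whereas evaluate_narrow(buggy, 1) lets the TypeError propagate while still returning None for nan_at_zero at 0.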

EricWay1024 commented 3 years ago

Then I guess we should change the exception raised here to a more specific exception class (see https://github.com/sympy/sympy/blob/b564d9ba2705ee2978766e1ee2102750834df68b/sympy/core/relational.py#L672-L677). However, we don't have a NaNError yet, and for consistency there may be other places we would have to change to raise or catch this new exception.

By the way, although I agree that an overly general exception should not be caught, I'm not sure why except TypeError is used in 100+ places across the repo. Personally, I feel it is also okay to use it here.

oscarbenjamin commented 3 years ago

I think most of the places in the repo that catch TypeError should be changed, and that we should really try to clean up exception handling.

The first problem, though, is the number of places that raise TypeError: if TypeError is raised, it is not possible to catch something more specific.

What we could do is this:

# proposed additions to sympy.core.coreerrors (BaseCoreError is defined there)

class NaNError(BaseCoreError):
    """Raised for invalid operations involving NaN"""
    pass

class NaNTypeError(NaNError, TypeError):
    """Raised for a NaNError that previously raised TypeError

    This is raised for backwards compatibility in places where TypeError
    was previously raised so that code that previously caught the TypeError
    can still catch the exception.

    New code should catch NaNError rather than NaNTypeError.
    """
    pass

With those, we can change the places that raise TypeError for invalid operations involving NaN so that they raise NaNTypeError instead. Then, in the places where we want to catch the exception, we can catch NaNError.
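A self-contained sketch of how the proposed hierarchy would behave. BaseCoreError is redefined locally as a stand-in so the snippet runs without SymPy, and compare_lt, old_caller and new_caller are hypothetical helpers, not existing SymPy functions.

```python
class BaseCoreError(Exception):
    """Local stand-in for sympy.core.coreerrors.BaseCoreError."""


class NaNError(BaseCoreError):
    """Raised for invalid operations involving NaN."""


class NaNTypeError(NaNError, TypeError):
    """NaNError raised where TypeError used to be raised (backwards compatible)."""


def compare_lt(a, b):
    """Hypothetical comparison helper that rejects NaN operands."""
    if a != a or b != b:  # NaN is the only float not equal to itself
        raise NaNTypeError("Invalid NaN comparison")
    return a < b


def old_caller(a, b):
    # Existing code that catches TypeError keeps working unchanged.
    try:
        return compare_lt(a, b)
    except TypeError:
        return None


def new_caller(f, *args):
    # New code catches only NaN problems; unrelated TypeErrors (e.g. calling
    # f with the wrong number of arguments) still propagate.
    try:
        return f(*args)
    except NaNError:
        return None
```

Because NaNTypeError inherits from both NaNError and TypeError, the old except TypeError handler still catches it, while except NaNError lets the missing-argument TypeError from a call like new_caller(compare_lt, 1.0) propagate.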