sagemath / sage

Main repository of SageMath
https://www.sagemath.org

Very slow `sage.rings.asymptotic.asymptotics_multivariate_generating_functions.diff_op` #35207

Open tornaria opened 1 year ago

tornaria commented 1 year ago

Environment

- **OS**: void
- **Sage Version**: 10.0.beta2

Steps To Reproduce

sage: from sage.rings.asymptotic.asymptotics_multivariate_generating_functions import diff_op
sage: T = var('x, y')
sage: A = function('A')(*tuple(T))
sage: B = function('B')(*tuple(T))
sage: AB_derivs = {}
sage: M = matrix([[1, 2],[2, 1]])
sage: %time DF = diff_op(A, B, AB_derivs, T, M, 1, 2)
CPU times: user 54.1 s, sys: 148 ms, total: 54.3 s
Wall time: 50.4 s

Expected Behavior

It takes less than 10ms.

Actual Behavior

It takes more than 50s.

Additional Information

This line

                    if product_derivs[idx] != ZZ.zero():

is evaluated for several symbolic expressions, and each evaluation can take significant time (up to ~2 s for some expressions).

On one hand, this test is repeated several times for the same expression. If each expression is tested only once the total time goes from ~54s down to ~12s.

On the other hand, this is still slow: if this test is removed the total time goes down to ~ 5ms.
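The first observation (testing each expression only once) amounts to caching the verdict of the zero test per key. A minimal plain-Python sketch of that idea follows; `slow_is_nonzero`, `count_nonzero_cached`, `table`, and `lookups` are hypothetical names used for illustration only and are not part of `diff_op`.

```python
def slow_is_nonzero(value):
    """Hypothetical stand-in for an expensive symbolic zero test."""
    slow_is_nonzero.calls += 1  # count how often the costly test runs
    return value != 0

slow_is_nonzero.calls = 0

def count_nonzero_cached(lookups, table):
    """Run the costly test at most once per key and reuse the verdict."""
    verdict = {}  # key -> bool, filled lazily on first lookup
    hits = 0
    for key in lookups:
        if key not in verdict:
            verdict[key] = slow_is_nonzero(table[key])
        if verdict[key]:
            hits += 1
    return hits

# The same keys are looked up repeatedly, as in the inner loop of diff_op.
table = {"a": 0, "b": 5, "c": 0}
lookups = ["a", "b", "a", "c", "b", "b"]
nonzero_hits = count_nonzero_cached(lookups, table)
```

Here six lookups trigger only three calls to the costly test, one per distinct key. The proof of concept below gets the same effect differently, by normalizing zeros once when `product_derivs` is filled.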

Here's a proof of concept:

--- a/src/sage/rings/asymptotic/asymptotics_multivariate_generating_functions.py
+++ b/src/sage/rings/asymptotic/asymptotics_multivariate_generating_functions.py
@@ -3978,7 +3978,7 @@ def diff_all(f, V, n, ending=[], sub=None, sub_final=None,
     return derivs

-def diff_op(A, B, AB_derivs, V, M, r, N):
+def diff_op(A, B, AB_derivs, V, M, r, N, check_zero=True):
     r"""
     Return the derivatives `DD^{(l+k)}(A[j] B^l)` evaluated at a point
     `p` for various natural numbers `j, k, l` which depend on `r` and `N`.
@@ -4048,6 +4048,7 @@ def diff_op(A, B, AB_derivs, V, M, r, N):
             for l in range(2 * k + 1):
                 for s in combinations_with_replacement(V, 2 * (k + l)):
                     DF = diff(A[j] * B ** l, list(s)).subs(AB_derivs)
+                    if check_zero and DF.is_zero(): DF = ZZ.zero()
                     product_derivs[(j, k, l) + s] = DF

     # Second, compute DD^(k+l)(A[j]*B^l)(p) and store values in dictionary.
@@ -4067,7 +4068,7 @@ def diff_op(A, B, AB_derivs, V, M, r, N):
                 diffo = ZZ.zero()
                 for t in P:
                     idx = (j, k, l) + diff_seq(V, t)
-                    if product_derivs[idx] != ZZ.zero():
+                    if product_derivs[idx] is not ZZ.zero():
                         MM = ZZ.one()
                         for (a, b) in t:
                             MM *= M[a][b]

With this in place:

sage: from sage.rings.asymptotic.asymptotics_multivariate_generating_functions import diff_op
sage: T = var('x, y')
sage: A = function('A')(*tuple(T))
sage: B = function('B')(*tuple(T))
sage: AB_derivs = {}
sage: M = matrix([[1, 2],[2, 1]])
sage: %time D1 = diff_op(A, B, AB_derivs, T, M, 1, 2)
CPU times: user 12.8 s, sys: 79 ms, total: 12.9 s
Wall time: 11.7 s
sage: %time D2 = diff_op(A, B, AB_derivs, T, M, 1, 2, check_zero=False)
CPU times: user 4.88 ms, sys: 1 µs, total: 4.88 ms
Wall time: 4.88 ms
sage: D1 == D2
True

If I understand the code correctly, removing the test still gives a correct answer, though the result may include some subexpressions that vanish. An alternative would be to simplify at the end, which takes < 2 s in this case, although it doesn't seem to change anything.
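The reasoning is that a zero term contributes nothing to a sum, so skipping it is only an optimization. A small numeric sketch of this, using exact rationals rather than symbolic expressions (`weighted_sum` and its arguments are hypothetical names, not taken from `diff_op`):

```python
from fractions import Fraction

def weighted_sum(terms, weights, check_zero=True):
    """Sum term * weight, optionally skipping terms that are zero."""
    total = Fraction(0)
    for term, weight in zip(terms, weights):
        if check_zero and term == 0:
            continue  # a zero term would add nothing to the total
        total += term * weight
    return total

terms = [Fraction(3), Fraction(0), Fraction(-1, 2), Fraction(0)]
weights = [Fraction(2), Fraction(7), Fraction(4), Fraction(1)]

with_check = weighted_sum(terms, weights, check_zero=True)
without_check = weighted_sum(terms, weights, check_zero=False)
```

Both calls return the same value. What this numeric sketch cannot show is the symbolic case, where a term that is mathematically zero may remain in the result as an unsimplified expression.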

@tscrim you are the last person to have touched this code, so you may know it or be able to point to someone who does. If it were up to me, I'd just remove the check for zero. Otherwise, the test has to be marked # long time, of course.

tscrim commented 1 year ago

I am not sure how good of a practice it is to base such decisions on certain doctests. This could cause problems for expressions that are 0 but otherwise look very complicated, e.g.,

sage: sin(x)^2 + cos(x)^2 - 1
cos(x)^2 + sin(x)^2 - 1

Yet I agree the result will be the same.

I am not sure if we can guarantee the result will be in SR, but there is an is_trivial_zero() that we could use. It won't deal with things that are 0 but expressed in a complicated way. We could safeguard this by checking that the parent of the left-hand side is SR though. This would keep nearly all of the speedup gained by avoiding complicated SR checks (although Maxima should be faster at this) while still skipping things that are trivially 0.
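A rough sketch of such a guarded check, written as plain Python so it can be tried outside Sage. It is an assumption-laden variant: instead of comparing the parent with SR, it looks for an `is_trivial_zero` method and falls back to ordinary equality otherwise. The helper `is_cheaply_zero` and the class `FakeSymbolic` are hypothetical and exist only for this illustration.

```python
def is_cheaply_zero(value):
    """Return True only when `value` is zero according to a cheap test."""
    trivial = getattr(value, "is_trivial_zero", None)
    if callable(trivial):
        # Symbolic-style object: structural check only, no simplification.
        return bool(trivial())
    # Anything else (e.g. plain numbers): ordinary equality is cheap.
    return value == 0

class FakeSymbolic:
    """Hypothetical stand-in for a symbolic expression outside Sage."""
    def __init__(self, text):
        self.text = text

    def is_trivial_zero(self):
        # Only the literal zero counts; nothing is simplified.
        return self.text == "0"

    def __eq__(self, other):
        # Models the expensive comparison the guard is meant to avoid.
        raise AssertionError("expensive comparison should not be called")

results = [
    is_cheaply_zero(0),
    is_cheaply_zero(3),
    is_cheaply_zero(FakeSymbolic("0")),
    is_cheaply_zero(FakeSymbolic("sin(x)^2 + cos(x)^2 - 1")),
]
```

The last entry is reported as nonzero even though the expression vanishes identically, which is exactly the limitation mentioned above: the term is kept and the result stays correct, only possibly unsimplified.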