sympy / sympy

A computer algebra system written in pure Python
https://sympy.org/

`expand()` does not expand matrix powers #25554

Open PauliusSasnauskas opened 1 year ago

PauliusSasnauskas commented 1 year ago

When working with matrices I expect the expand() command to expand matrix powers (example use case: checking if two matrix expressions are equal).

Example:

import sympy as sp

n, m = sp.symbols('n m')
A = sp.MatrixSymbol('A', n, m)

result = (A.T @ A) @ (A.T @ A) - A.T @ A @ A.T @ A
print(result)

I expect the result to be 0 based on matmul associativity, but it is giving me (A.T*A)**2 - A.T*A*A.T*A. No combination of expand(), doit(), or simplify() changes the output.

For some cases (although not this one), the following hack seems to do the trick:

terms, (expr,) = sp.cse(result)
result = expr.subs(terms).doit().expand()

with an arbitrary number and combination of subs(terms), expand(), doit(), and simplify() for the result.

I have not found any way to expand matrix powers. Anyone have an idea?

Python 3.11.4, SymPy 1.12

oscarbenjamin commented 1 year ago

I thought that factor_nc might work but it fails:

In [8]: factor_nc(result)
---------------------------------------------------------------------------
...
TypeError: Mix of Matrix and Scalar symbols

It is not completely clear what the expected end result of expand would be here. I suppose it is by analogy with this:

In [14]: x, y = symbols('x, y', commutative=False)

In [15]: (x*y)**2
Out[15]: 
     2
(x⋅y) 

In [16]: expand((x*y)**2)
Out[16]: x⋅y⋅x⋅y

In [17]: (x*y)**2 - x*y*x*y
Out[17]: 
                2
-x⋅y⋅x⋅y + (x⋅y) 

In [18]: expand((x*y)**2 - x*y*x*y)
Out[18]: 0

Probably factor_nc should be improved for powers as well:

In [19]: factor_nc((x*y)**2 - x*y*x*y)
Out[19]: 
                2
-x⋅y⋅x⋅y + (x⋅y) 

sylee957 commented 1 year ago

It's not very difficult to implement the logic for matrix power expansion, but the problem is deciding which expand option it 'fits' under. https://github.com/sympy/sympy/blob/73484029c5f8a02da1a6f91419f39c6ed4cfdf57/sympy/core/function.py#L2486-L2487

No _eval_expand_... methods are implemented for matrix expressions yet.

asmeurer commented 1 year ago

expand_power_base expands (x*y)**a into x**a * y**a. The equivalent for noncommutatives is to expand out an integer power, like (x*y)**2 -> x*y*x*y, so I'd say it should do that (probably with a default limit on the size of the exponent that gets expanded).
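A minimal sketch of that idea for matrix expressions, assuming only the public MatPow/MatMul classes and Basic.replace. The helper name expand_matpow and its max_exp limit are hypothetical and not part of SymPy. The product is left unevaluated on purpose, because MatMul canonicalization can combine repeated factors back into a power.

```python
from sympy import MatMul, MatPow, MatrixSymbol, symbols


def expand_matpow(expr, max_exp=8):
    """Unroll MatPow(base, k) into a flat MatMul for integer 2 <= k <= max_exp.

    Hypothetical helper for illustration; not a SymPy API.
    """
    def is_small_integer_power(e):
        return (isinstance(e, MatPow)
                and e.args[1].is_Integer
                and 2 <= int(e.args[1]) <= max_exp)

    def unroll(e):
        base, exp = e.args
        # Flatten a MatMul base by hand so the product is not nested.
        factors = list(base.args) if isinstance(base, MatMul) else [base]
        # Keep the MatMul unevaluated: canonicalizing may rebuild the power.
        return MatMul(*(factors * int(exp)))

    return expr.replace(is_small_integer_power, unroll)


n, m = symbols("n m")
A = MatrixSymbol("A", n, m)

power = (A.T * A) ** 2
expanded = expand_matpow(power)
print(power, "->", expanded)
```

This sketch does not handle scalar coefficients inside the base (for example (2*A)**2), negative or symbolic exponents, or sums; those would need the mul and multinomial hints discussed in this thread.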

sylee957 commented 1 year ago

I would also like to fill out the other options, like mul or multinomial, for matrices.

asmeurer commented 1 year ago

I think the more general problem is that things that work on noncommutative expressions don't automatically work for matrix expressions:

>>> A, B = symbols("A B", commutative=False)
>>> expand((A + B)**2)
A*B + A**2 + B*A + B**2
>>> A, B = MatrixSymbol("A", n, n), MatrixSymbol("B", n, n)
>>> expand((A + B)**2)
(A + B)**2

Ideally we shouldn't have to reimplement everything twice on both MatrixExpr and commutative=False Expr.

asmeurer commented 1 year ago

By the way, the factor_nc issue with matrix expressions is https://github.com/sympy/sympy/issues/24980

sylee957 commented 1 year ago

I think that this fixes the problem

    _eval_expand_power_base = Pow._eval_expand_power_base
    _eval_expand_multinomial = Pow._eval_expand_multinomial

but that may not be a complete fix, because some of the reused logic looks suspicious, e.g. the handling of rational powers, which may not make sense here.

"Ideally we shouldn't have to reimplement everything twice on both MatrixExpr and commutative=False Expr."

However, that doesn't mean the code has to be reused through inheritance.

asmeurer commented 1 year ago

MatPow is the only class that doesn't inherit from the respective core class. MatAdd and MatMul already subclass from Add and Mul. I don't know if there's a good reason for that. That's the only reason why this doesn't work.
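The class relationships described above can be checked directly. This is a small sketch using only top-level SymPy names; the printed result for MatPow reflects whatever SymPy version is installed, so no output is asserted for it here.

```python
from sympy import Add, MatAdd, MatMul, MatPow, Mul, Pow

# Compare each matrix expression class with its core counterpart.
pairs = [(MatAdd, Add), (MatMul, Mul), (MatPow, Pow)]
for mat_cls, core_cls in pairs:
    print(f"{mat_cls.__name__} subclasses {core_cls.__name__}:",
          issubclass(mat_cls, core_cls))
```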

For things like factor_nc, the issue comes from the fact that matrices aren't exactly like commutative=False symbols, because they are more restrictive on what they can do. You can't combine matrix expressions unless their shapes match, and more importantly, you can't add a matrix expression to a non-matrix expression or multiply a matrix expression with a noncommutative non-matrix expression.

However, I would say that those restrictions are there to protect against constructing nonsense with matrix expressions, and the fact that they are there means that we actually can just reuse core logic for matrix expressions, because it will automatically fail if it tries to do something that doesn't make sense for matrices.
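A short illustration of two of those restrictions, using concrete shapes so the checks are unambiguous. The raises helper is a hypothetical convenience defined here, not a SymPy function. The TypeError is the same kind of error that factor_nc hit earlier in this thread.

```python
from sympy import MatrixSymbol, ShapeError, Symbol

A = MatrixSymbol("A", 2, 3)
B = MatrixSymbol("B", 2, 3)
x = Symbol("x", commutative=False)


def raises(exc_type, thunk):
    """Return True if calling thunk() raises exc_type (hypothetical helper)."""
    try:
        thunk()
    except exc_type:
        return True
    return False


# A is 2x3 and B is 2x3, so the product A*B is not aligned.
shape_rejected = raises(ShapeError, lambda: A * B)
# A matrix expression cannot be added to a non-matrix expression.
mix_rejected = raises(TypeError, lambda: A + x)

print(shape_rejected, mix_rejected)
```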

sylee957 commented 1 year ago

The other reason is to avoid multiple inheritance.

sylee957 commented 1 year ago

================================================================================= FAILURES =================================================================================
____________________________________________________________________ test_matrix_derivatives_of_traces _____________________________________________________________________

>   ???
E   assert ArrayAdd(PermuteDims(ArrayTensorProduct(X.T, I), (1 2 3)), PermuteDims(ArrayTensorProduct(I, X), (1 2 3))) == (2 * X.T)
E    +  where ArrayAdd(PermuteDims(ArrayTensorProduct(X.T, I), (1 2 3)), PermuteDims(ArrayTensorProduct(I, X), (1 2 3))) = <bound method Expr.diff of Trace(X**2)>(X)
E    +    where <bound method Expr.diff of Trace(X**2)> = Trace(X**2).diff
E    +  and   X.T = X.T

/workspaces/sympy/sympy/matrices/expressions/tests/test_derivatives.py:233: AssertionError
______________________________________________________________________ test_derivatives_matrix_norms _______________________________________________________________________

>   ???

/workspaces/sympy/sympy/matrices/expressions/tests/test_derivatives.py:369: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
sympy/core/expr.py:3586: in diff
    return _derivative_dispatch(self, *symbols, **assumptions)
sympy/core/function.py:1907: in _derivative_dispatch
    return ArrayDerivative(expr, *variables, **kwargs)
sympy/tensor/array/array_derivatives.py:19: in __new__
    obj = super().__new__(cls, expr, *variables, **kwargs)
sympy/core/function.py:1436: in __new__
    obj = cls._dispatch_eval_derivative_n_times(expr, v, count)
sympy/tensor/array/array_derivatives.py:118: in _dispatch_eval_derivative_n_times
    result = cls._call_derive_default(expr, v)
sympy/tensor/array/array_derivatives.py:77: in _call_derive_default
    return _matrix_derivative(expr, v)
sympy/matrices/expressions/matexpr.py:537: in _matrix_derivative
    array_expr = convert_matrix_to_array(expr)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

expr = (x.T*y)**(1/2)

    def convert_matrix_to_array(expr: Basic) -> Basic:
        if isinstance(expr, MatMul):
            args_nonmat = []
            args = []
            for arg in expr.args:
                if isinstance(arg, MatrixExpr):
                    args.append(arg)
                else:
                    args_nonmat.append(convert_matrix_to_array(arg))
            contractions = [(2*i+1, 2*i+2) for i in range(len(args)-1)]
            scalar = _array_tensor_product(*args_nonmat) if args_nonmat else S.One
            if scalar == 1:
                tprod = _array_tensor_product(
                    *[convert_matrix_to_array(arg) for arg in args])
            else:
                tprod = _array_tensor_product(
                    scalar,
                    *[convert_matrix_to_array(arg) for arg in args])
            return _array_contraction(
                    tprod,
                    *contractions
            )
        elif isinstance(expr, MatAdd):
            return _array_add(
                    *[convert_matrix_to_array(arg) for arg in expr.args]
            )
        elif isinstance(expr, Transpose):
            return _permute_dims(
                    convert_matrix_to_array(expr.args[0]), [1, 0]
            )
        elif isinstance(expr, Trace):
            inner_expr: MatrixExpr = convert_matrix_to_array(expr.arg) # type: ignore
            return _array_contraction(inner_expr, (0, len(inner_expr.shape) - 1))
        elif isinstance(expr, Mul):
            return _array_tensor_product(*[convert_matrix_to_array(i) for i in expr.args])
        elif isinstance(expr, Pow):
            base = convert_matrix_to_array(expr.base)
            if (expr.exp > 0) == True:
>               return _array_tensor_product(*[base for i in range(expr.exp)])
E               TypeError: 'Half' object cannot be interpreted as an integer

sympy/tensor/array/expressions/from_matrix_to_array.py:59: TypeError
                                                                              DO *NOT* COMMIT!                                                                              
========================================================================= short test summary info ==========================================================================
FAILED sympy/matrices/expressions/tests/test_derivatives.py::test_matrix_derivatives_of_traces - assert ArrayAdd(PermuteDims(ArrayTensorProduct(X.T, I), (1 2 3)), PermuteDims(ArrayTensorProduct(I, X), (1 2 3))) == (2 * X.T)
FAILED sympy/matrices/expressions/tests/test_derivatives.py::test_derivatives_matrix_norms - TypeError: 'Half' object cannot be interpreted as an integer
========================================================= 2 failed, 256 passed, 1 deselected, 6 xfailed in 16.65s ==========================================================

The tests don't pass if I change MatPow to inherit from Pow, though.

sylee957 commented 1 year ago

The failure is interesting because if I change the inheritance, the dispatching logic in 'from_matrix_to_array' breaks.

It means that the if-isinstance checks have to be ordered in reverse of the inheritance, with subclasses tested before their base classes. This looks reasonable, but it seems like a bug that would be hard to catch except by accident.
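The ordering problem can be reproduced without SymPy. In the sketch below, Pow and MatPow are stand-in classes (not the real SymPy ones), and MatPow is hypothetically made a subclass of Pow, as in the experiment above. The two convert_* functions are illustrative names only.

```python
class Pow:
    """Stand-in for the core power class."""


class MatPow(Pow):
    """Stand-in for a matrix power that hypothetically subclasses Pow."""


def convert_base_first(expr):
    # The base-class check comes first, so it also captures MatPow instances.
    if isinstance(expr, Pow):
        return "Pow branch"
    elif isinstance(expr, MatPow):
        return "MatPow branch"  # unreachable once MatPow subclasses Pow
    raise TypeError(f"unsupported: {type(expr).__name__}")


def convert_subclass_first(expr):
    # Checking the subclass first restores the intended dispatch.
    if isinstance(expr, MatPow):
        return "MatPow branch"
    elif isinstance(expr, Pow):
        return "Pow branch"
    raise TypeError(f"unsupported: {type(expr).__name__}")


print(convert_base_first(MatPow()))      # Pow branch
print(convert_subclass_first(MatPow()))  # MatPow branch
```

With the base class tested first, a matrix power falls into the scalar Pow branch, which is consistent with the range(expr.exp) call in the traceback receiving a non-integer exponent.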