pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

BUG: `pt.linalg.solve` returns incorrect results when `mode = "NUMBA"` #422

Open jessegrabowski opened 1 year ago

jessegrabowski commented 1 year ago

Describe the issue:

This is the same bug as #382. When `assume_a != 'gen'` is passed to `pt.linalg.solve`, the Numba backend incorrectly calls `solve_triangular`, so the output is wrong.
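To see why dispatching to a triangular solver gives wrong answers here, the sketch below uses SciPy directly (not PyTensor or its Numba backend). `scipy.linalg.solve_triangular` reads only one triangle of the matrix, so on a dense symmetric matrix it solves a different system than `scipy.linalg.solve(..., assume_a="sym")`. The 2×2 matrix and right-hand side are arbitrary illustrative values.

```python
import numpy as np
from scipy.linalg import solve, solve_triangular

# Symmetric positive-definite matrix with a nonzero off-diagonal (illustrative values)
A = np.array([[4.0, 1.0],
              [1.0, 3.0]])
b = np.array([1.0, 2.0])

# Correct: treat A as symmetric and solve the full system A @ x == b
x_sym = solve(A, b, assume_a="sym")

# Wrong for this A: only the lower triangle [[4, 0], [1, 3]] is read,
# so this solves tril(A) @ x == b instead
x_tri = solve_triangular(A, b, lower=True)

print(np.allclose(A @ x_sym, b))  # True
print(np.allclose(A @ x_tri, b))  # False
```

The triangular result only agrees with the full solve when the ignored triangle happens to be zero, which is never the case for a dense symmetric input.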

Reproducible code example:

```python
import pytensor
import pytensor.tensor as pt
import numpy as np

X = pt.dmatrix('X')
b = pt.dmatrix('b')
y1 = pt.linalg.solve(X, b, assume_a='gen')
y2 = pt.linalg.solve(X, b, assume_a='sym')

f1 = pytensor.function([X, b], y1)
f2 = pytensor.function([X, b], y2)
f1_nb = pytensor.function([X, b], y1, mode='NUMBA')
f2_nb = pytensor.function([X, b], y2, mode='NUMBA')

# Build a symmetric positive-definite test matrix
X_sym = np.random.normal(size=(3, 3))
X_sym = X_sym @ X_sym.T

# Solving against the identity should return the inverse of X_sym
X_inv_1 = f1(X_sym, np.eye(3))
X_inv_2 = f2(X_sym, np.eye(3))
X_inv_1_nb = f1_nb(X_sym, np.eye(3))
X_inv_2_nb = f2_nb(X_sym, np.eye(3))

# Passes, C backend, assume_a = 'gen'
np.testing.assert_allclose(X_inv_1 @ X_sym, np.eye(3), atol=1e-12)

# Passes, C backend, assume_a = 'sym'
np.testing.assert_allclose(X_inv_2 @ X_sym, np.eye(3), atol=1e-12)

# Passes, Numba backend, assume_a = 'gen'
np.testing.assert_allclose(X_inv_1_nb @ X_sym, np.eye(3), atol=1e-12)

# Fails, Numba backend, assume_a = 'sym'
np.testing.assert_allclose(X_inv_2_nb @ X_sym, np.eye(3), atol=1e-12)
```

Error message:

```shell
AssertionError:
Not equal to tolerance rtol=1e-07, atol=1e-12

Mismatched elements: 6 / 9 (66.7%)
Max absolute difference: 0.67369019
Max relative difference: 0.52399333
 x: array([[ 4.760067e-01, -2.459194e-01,  9.939378e-18],
       [-3.325732e-01,  6.636944e-01,  1.353323e-18],
       [-1.317192e-01,  6.736902e-01,  1.000000e+00]])
 y: array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])
```

PyTensor version information:

PyTensor version 2.11.1

Context for the issue:

No response

ricardoV94 commented 11 months ago

@jessegrabowski Is this one resolved?

jessegrabowski commented 11 months ago

No. I am working on a Numba overload for `solve`, but it's more work than I expected. In the meantime, maybe I could add a temporary quick fix that either raises an error if `assume_a` is not `"gen"`, or silently forces it to `"gen"`?
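A minimal sketch of the two quick-fix options, written as a standalone helper. The name `resolve_numba_assume_a` and its `strict` flag are hypothetical and not part of PyTensor; the only assumption carried over from this thread is that the general (`"gen"`) path is the one that currently works in Numba mode.

```python
def resolve_numba_assume_a(assume_a: str, strict: bool = False) -> str:
    # Hypothetical helper, not PyTensor's API.
    # Assumption: only the general ("gen") solve path is correct in Numba mode.
    if assume_a == "gen":
        return "gen"
    if strict:
        # Option 1: refuse structures the Numba backend cannot handle yet
        raise NotImplementedError(
            f"assume_a={assume_a!r} is not supported in Numba mode; use 'gen'"
        )
    # Option 2: silently fall back to the general solver
    return "gen"


print(resolve_numba_assume_a("sym"))  # gen

try:
    resolve_numba_assume_a("sym", strict=True)
except NotImplementedError as err:
    print(err)
```

The silent fallback keeps existing user code running; the strict variant makes the limitation visible at the cost of breaking callers that pass `"sym"` or `"pos"`.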

ricardoV94 commented 11 months ago

Yes, you can always treat it as `"gen"` until the alternatives are working.
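A quick sanity check of why always using `"gen"` is safe, sketched with SciPy rather than PyTensor's Numba code: an LU-based general solve is valid for any nonsingular matrix, including symmetric ones, so forcing `assume_a="gen"` gives up only the potential speedup of a specialized symmetric solver, not correctness. The 3×3 matrix is an arbitrary symmetric positive-definite example.

```python
import numpy as np
from scipy.linalg import solve

# Arbitrary symmetric positive-definite matrix (leading minors 4, 11, 18)
X_sym = np.array([[4.0, 1.0, 0.0],
                  [1.0, 3.0, 1.0],
                  [0.0, 1.0, 2.0]])
B = np.eye(3)

# General LU-based solve vs. the symmetric-specialized solve
X_inv_gen = solve(X_sym, B, assume_a="gen")
X_inv_sym = solve(X_sym, B, assume_a="sym")

# Both should agree and both should invert X_sym
print(np.allclose(X_inv_gen, X_inv_sym))
print(np.allclose(X_inv_gen @ X_sym, B))
```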