JaxGaussianProcesses / JaxLinOp

Linear Operators in JAX.
Apache License 2.0

feat: Overload jax.numpy operations for LinearOperator compatibility. #2

Open daniel-dodd opened 1 year ago


The problem:

Currently, to add a linear operator (e.g., a diagonal linear operator) to a matrix, we do the following:

from jaxlinop import DiagonalLinearOperator
import jax.numpy as jnp

M = jnp.array([[1., 2.], [3., 4.]])
D = DiagonalLinearOperator(jnp.array([1.,2.]))

res = D + M

print(res.to_dense())
[[2. 2.]
 [3. 6.]]
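As a sanity check on what `D + M` computes: a diagonal operator with entries `d` behaves like the dense matrix `jnp.diag(d)`, so the result above equals `jnp.diag(d) + M`, with only the diagonal of `M` changed. The short check below uses plain jax.numpy only (no jaxlinop), with `d` and `dense_sum` as illustrative names.

```python
import jax.numpy as jnp

# Same matrix and diagonal entries as in the example above.
M = jnp.array([[1.0, 2.0], [3.0, 4.0]])
d = jnp.array([1.0, 2.0])

# A diagonal operator with entries d acts like the dense matrix jnp.diag(d).
dense_sum = jnp.diag(d) + M
print(dense_sum)

# Only the diagonal of M is shifted by d; off-diagonal entries are unchanged.
assert jnp.allclose(jnp.diagonal(dense_sum), jnp.diagonal(M) + d)
assert jnp.allclose(dense_sum - jnp.diag(jnp.diagonal(dense_sum)),
                    M - jnp.diag(jnp.diagonal(M)))
```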

This works well. However, if we use jnp.add instead, we get an error:

res = jnp.add(D, M)
TypeError: add requires ndarray or scalar arguments, got <class 'jaxlinop.diagonal_linear_operator.DiagonalLinearOperator'> at position 0.

It would be nice to overload jnp.add so that it returns D + M instead. This would be particularly useful for matrix solves: jnp.linalg.solve(D, M) would be more familiar to JAX users than D.solve(M).
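At its core, such an overload is a type-dispatch step in front of the jax.numpy function: if an argument is a linear operator, defer to the operator's own method; otherwise fall back to the stock JAX function. The sketch below assumes a hypothetical `Diagonal` class standing in for `DiagonalLinearOperator` and hypothetical wrappers `add` and `solve`; none of these are part of jaxlinop or JAX. For a diagonal operator, solving `D X = M` reduces to dividing row `i` of `M` by `d[i]`. Unlike jaxlinop, where `D + M` returns a LinearOperator (hence `res.to_dense()`), this toy `__add__` returns a plain dense array.

```python
import jax.numpy as jnp


class Diagonal:
    """Hypothetical stand-in for jaxlinop's DiagonalLinearOperator (illustration only)."""

    def __init__(self, diag):
        self.diag = jnp.asarray(diag)

    def to_dense(self):
        return jnp.diag(self.diag)

    def __add__(self, other):
        # Diagonal + Diagonal stays diagonal; otherwise densify and add.
        if isinstance(other, Diagonal):
            return Diagonal(self.diag + other.diag)
        return self.to_dense() + jnp.asarray(other)

    # Addition is commutative, so M + D can reuse the same logic.
    __radd__ = __add__

    def solve(self, rhs):
        # D^{-1} rhs: divide row i of rhs by diag[i] (vectors are divided elementwise).
        rhs = jnp.asarray(rhs)
        return rhs / (self.diag[:, None] if rhs.ndim == 2 else self.diag)


def add(x, y):
    """Hypothetical wrapper: use the operator's + if either argument is an operator."""
    # In jaxlinop this check would be against the LinearOperator base class.
    if isinstance(x, Diagonal) or isinstance(y, Diagonal):
        return x + y
    return jnp.add(x, y)


def solve(a, b):
    """Hypothetical wrapper mirroring jnp.linalg.solve(a, b), but using a.solve(b) for operators."""
    if isinstance(a, Diagonal):
        return a.solve(b)
    return jnp.linalg.solve(a, b)


D = Diagonal(jnp.array([1.0, 2.0]))
M = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(add(D, M))    # dispatches to Diagonal.__add__
print(solve(D, M))  # dispatches to Diagonal.solve, no dense factorisation
```

The operator path never builds or factorises a dense matrix for the solve, which is the main reason to route `solve` through the operator rather than calling `jnp.linalg.solve` on `D.to_dense()`.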

Proposed solution

Overload jax.numpy operations, but only when activated via a global config option.

from jaxlinop import config

config.update("overload_jax_numpy", True)

res = jnp.add(D, M)

print(res.to_dense())
[[2. 2.]
 [3. 6.]]