The problem:
Currently, to add a linear operator (e.g., a diagonal linear operator) to a matrix, we do the following:
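```python
from jaxlinop import DiagonalLinearOperator
import jax.numpy as jnp

M = jnp.array([[1., 2.], [3., 4.]])
D = DiagonalLinearOperator(jnp.array([1., 2.]))
res = D + M               # operator + dense matrix via the overloaded + operator
print(res.to_dense())     # materialise the result as a dense array
```

```
[[2. 2.]
 [3. 6.]]
```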
This works. However, if we try to use `jnp.add` instead, we get an error:
```python
res = jnp.add(D, M)
```

```
TypeError: add requires ndarray or scalar arguments, got <class 'jaxlinop.diagonal_linear_operator.DiagonalLinearOperator'> at position 0.
```
It would be nice to overload `jnp.add` so that it returns `D + M` instead. This would be particularly useful for matrix solves: `jnp.linalg.solve(D, M)` would be more familiar to JAX users than `D.solve(M)`.
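To make the solve case concrete: for a diagonal operator D = diag(d), solving D X = M only requires dividing row i of M by d[i], whereas `jnp.linalg.solve` on the materialised matrix performs a general dense (LU-based) solve. The sketch below uses plain `jax.numpy` only (no jaxlinop) and assumes this row scaling is what a diagonal operator's `solve` computes; the library's actual implementation may differ.

```python
import jax.numpy as jnp

d = jnp.array([1., 2.])               # diagonal entries of D
M = jnp.array([[1., 2.], [3., 4.]])   # right-hand side

# Dense route: build the full n x n matrix and run a general solve.
dense = jnp.linalg.solve(jnp.diag(d), M)

# Structured route: D^{-1} M scales row i of M by 1 / d[i].
structured = M / d[:, None]

print(dense)
print(structured)
```

An overloaded `jnp.linalg.solve(D, M)` would let users keep the familiar call while still getting the cheaper structured route.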
Proposed solution
Overload `jax.numpy` operations, but only when activated via a global config option:
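```python
from jaxlinop import config  # proposed global config option

# Opt in to overloading jax.numpy functions for linear operators
config.update("overload_jax_numpy", True)
res = jnp.add(D, M)  # would dispatch to D + M
print(res.to_dense())
```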
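A minimal sketch of how such a flag could gate dispatch, assuming config-aware wrapper functions rather than patching `jax.numpy` in place. All names here (`_CONFIG`, `update`, `DiagonalOp`, `add`, `solve`) are hypothetical and for illustration only; `DiagonalOp` is a toy stand-in, not jaxlinop's `DiagonalLinearOperator`, and it returns dense arrays for brevity where jaxlinop returns operators.

```python
import jax.numpy as jnp

# Hypothetical global config mirroring the proposed config.update(...) call.
_CONFIG = {"overload_jax_numpy": False}


def update(name, value):
    """Set a config flag (hypothetical stand-in for the proposed jaxlinop config)."""
    if name not in _CONFIG:
        raise KeyError(f"unknown config option: {name}")
    _CONFIG[name] = value


class DiagonalOp:
    """Toy diagonal operator; NOT the jaxlinop DiagonalLinearOperator class."""

    def __init__(self, diag):
        self.diag = jnp.asarray(diag)

    def to_dense(self):
        return jnp.diag(self.diag)

    def __add__(self, other):
        # Dense result for brevity; a real operator library would keep structure.
        return self.to_dense() + other

    def solve(self, rhs):
        # D^{-1} rhs: divide row i of rhs by diag[i] (works for vectors and matrices).
        rhs = jnp.asarray(rhs)
        scale = self.diag if rhs.ndim == 1 else self.diag[:, None]
        return rhs / scale


def add(x, y):
    """Config-gated replacement for jnp.add."""
    if _CONFIG["overload_jax_numpy"]:
        if isinstance(x, DiagonalOp):
            return x + y
        if isinstance(y, DiagonalOp):
            return y + x  # addition commutes, so reuse the operator's __add__
    return jnp.add(x, y)  # default behaviour: raises TypeError for operators


def solve(a, b):
    """Config-gated replacement for jnp.linalg.solve."""
    if _CONFIG["overload_jax_numpy"] and isinstance(a, DiagonalOp):
        return a.solve(b)
    return jnp.linalg.solve(a, b)
```

Gating on a flag keeps the default `jax.numpy` behaviour unchanged for users who do not opt in; a real implementation would still need to decide whether to patch `jax.numpy` attributes globally or expose its own namespace of wrappers.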