It would be useful if `DenseArrayValue` could be raised to an integer power.

Minimal reproduction:
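```python
import jax
import jax.numpy as jnp
from quax import DenseArrayValue, quaxify

# Without x64 mode, JAX silently downcasts float64 arrays to float32.
jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
DenseArrayValue(x) ** 2
# TypeError: unsupported operand type(s) for ** or pow(): 'DenseArrayValue' and 'int'
```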
Fair enough. Looking at `ArrayValue`, there's no `__pow__` defined. If I add it:
```python
import operator

class ArrayValue(Value):
    ...
    __add__ = quaxify(operator.add)
    ...
    __pow__ = quaxify(operator.pow)  # proposed addition
```
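For background on why the `TypeError` appears: Python evaluates `a ** b` by looking up `__pow__` on the type of `a` (then `__rpow__` on the type of `b`) and raises `TypeError` when neither handles the operands. The sketch below mimics the proposed change using a hypothetical `Box` class and a hypothetical `lifted` helper as stand-ins for `ArrayValue` and `quaxify`. Neither is part of quax, and the real `quaxify` does much more (it intercepts JAX primitives) than unwrap and re-wrap values.

```python
import operator


def lifted(fn):
    """Hypothetical stand-in for a wrapper like quaxify: unwraps Box
    arguments, applies fn to the raw values, and re-wraps the result."""
    def wrapper(*args):
        raw = [a.value if isinstance(a, Box) else a for a in args]
        return Box(fn(*raw))
    return wrapper


class Box:
    """Hypothetical minimal wrapper type, not part of quax."""

    def __init__(self, value):
        self.value = value

    # `a + b` resolves to type(a).__add__(a, b).
    __add__ = lifted(operator.add)


# No __pow__ yet, and int.__rpow__ returns NotImplemented for a Box.
try:
    Box(3) ** 2
except TypeError as e:
    print(e)  # unsupported operand type(s) for ** or pow(): 'Box' and 'int'

# Same pattern as the proposed ArrayValue change.
Box.__pow__ = lifted(operator.pow)
print((Box(3) ** 2).value)  # 9
```

Assigning the dunder on the class (rather than on an instance) matters: special-method lookup goes through the type, so an instance attribute named `__pow__` would be ignored by `**`.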