patrick-kidger / quax

Multiple dispatch over abstract array types in JAX.
Apache License 2.0

add `__pow__` #6

Closed. nstarman closed this 10 months ago.

nstarman commented 10 months ago

It would be useful if `DenseArrayValue` could be raised to an integer power. Currently it cannot.

Basic test:

import jax
import jax.numpy as jnp
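from quax import DenseArrayValue, quaxify

# float64 arrays require JAX's 64-bit mode; without it JAX falls back to float32.
jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)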
DenseArrayValue(x) ** 2
> TypeError: unsupported operand type(s) for ** or pow(): 'DenseArrayValue' and 'int'
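Python evaluates `a ** b` by first trying `type(a).__pow__(a, b)` and then, if that is missing or returns `NotImplemented`, the reflected `type(b).__rpow__(b, a)`. Only when both fail does it raise the `TypeError` above. A minimal sketch with a hypothetical stand-in class `NoPowWrapper` (not part of quax) that, like `DenseArrayValue` before this change, defines no `__pow__`:

```python
class NoPowWrapper:
    """Hypothetical stand-in for an array wrapper with no __pow__ or __rpow__."""

    def __init__(self, values):
        self.values = list(values)


def try_pow(base, exponent):
    # Return the TypeError message instead of raising, so the failure is easy to inspect.
    try:
        base ** exponent
    except TypeError as e:
        return str(e)
    return None


message = try_pow(NoPowWrapper([1.0, 2.0, 3.0]), 2)
print(message)
# unsupported operand type(s) for ** or pow(): 'NoPowWrapper' and 'int'
```

Since `int` has no `__rpow__` that understands the wrapper either, both lookups fail and the error names the two operand types.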

Fair enough. Looking at `ArrayValue`, there's no `__pow__` defined. If I add one:

import operator

# Excerpt of quax's ArrayValue; `...` marks members omitted here.
class ArrayValue(Value):
    ...
    __add__ = quaxify(operator.add)
    ...
    __pow__ = quaxify(operator.pow)

DenseArrayValue(x) ** 2
> Array([1., 4., 9.], dtype=float64)
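The `__pow__ = quaxify(operator.pow)` pattern works because a function stored as a class attribute binds as a method, so `x ** 2` becomes `f(x, 2)`. A minimal sketch of the same pattern, using a hypothetical `elementwise` helper in place of `quaxify` (it only maps the operator over stored values and does none of quax's dispatch) and a hypothetical `Wrapped` class:

```python
import operator


def _unwrap(x, n):
    # Broadcast a plain scalar to length n; pass Wrapped values through.
    return x.values if isinstance(x, Wrapped) else [x] * n


def elementwise(op):
    """Hypothetical stand-in for quaxify: apply a binary operator element by element."""

    def fn(a, b):
        n = len(a.values) if isinstance(a, Wrapped) else len(b.values)
        return Wrapped(op(u, v) for u, v in zip(_unwrap(a, n), _unwrap(b, n)))

    return fn


class Wrapped:
    """Hypothetical array wrapper (not part of quax) holding a list of floats."""

    def __init__(self, values):
        self.values = [float(v) for v in values]

    # A plain function as a class attribute binds as a method: x ** 2 -> fn(x, 2).
    __pow__ = elementwise(operator.pow)


x = Wrapped([1, 2, 3])
print((x ** 2).values)  # [1.0, 4.0, 9.0]
```

This relies on the attribute binding to the instance. A callable object without a `__get__` method would be called with only the right-hand operand. The output in the issue suggests the object returned by `quaxify` binds correctly, though the thread does not say so explicitly.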
patrick-kidger commented 10 months ago

Thank you! Can you add `__rpow__` as well?

nstarman commented 10 months ago

Can do! I'm trying to think of when this would be useful. Perhaps something like:

2 ** quax.zero.Zero(3)
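For an expression like `2 ** quax.zero.Zero(3)`, `int.__pow__` returns `NotImplemented` for the unfamiliar type, so Python falls back to the reflected method on the right operand, called as `__rpow__(zero, 2)`. The arguments arrive swapped: `__rpow__(self, other)` must compute `other ** self`, so assigning the same `quaxify(operator.pow)` to `__rpow__` unchanged would compute `self ** other` instead. A minimal sketch with a hypothetical `flip` helper and a toy `ZeroToy` class (neither is quax's `Zero` or its actual implementation):

```python
def flip(fn):
    """Return a function that calls fn with its two arguments swapped."""

    def flipped(a, b):
        return fn(b, a)

    return flipped


def zero_pow(base, exponent):
    """Hypothetical stand-in for a dispatched operator.pow where one side is ZeroToy."""
    if isinstance(base, ZeroToy):
        return [0.0 ** exponent] * base.n  # 0 ** k, elementwise
    return [base ** 0.0] * exponent.n  # k ** 0 == 1, elementwise


class ZeroToy:
    """Hypothetical toy array of n zeros, stored only by its length."""

    def __init__(self, n):
        self.n = n

    __pow__ = zero_pow  # ZeroToy(n) ** k -> zero_pow(self, k)
    __rpow__ = flip(zero_pow)  # k ** ZeroToy(n) -> zero_pow(k, self)


print(2 ** ZeroToy(3))  # [1.0, 1.0, 1.0]
print(ZeroToy(3) ** 2)  # [0.0, 0.0, 0.0]
```

Without `flip`, `__rpow__ = zero_pow` would evaluate `2 ** ZeroToy(3)` as `zero_pow(ZeroToy(3), 2)` and return zeros, silently swapping base and exponent.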
patrick-kidger commented 10 months ago

Alright, LGTM! Thank you for adding this. :)