jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unimplemented: binary integer op 'power' #48

Closed: iandanforth closed this issue 5 years ago

iandanforth commented 5 years ago
from jax import grad

def square(x):
  return x**2

val = 3
dfn = grad(square)
print(dfn(val))  # raises an error because val is an int

I was surprised this threw an error. Changing it to val = 3.0 works as expected.
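A minimal sketch of the workaround described above: pass a floating-point value to the differentiated function, either as a Python float or by casting an integer explicitly. The choice of float32 for the cast is an assumption for illustration.

```python
import jax.numpy as jnp
from jax import grad

def square(x):
    return x ** 2

dfn = grad(square)

val = 3
# Differentiate at a float rather than a Python int.
print(dfn(float(val)))                            # gradient 2 * 3.0 = 6.0
# Equivalent: cast the integer to a floating-point JAX array first.
print(dfn(jnp.asarray(val, dtype=jnp.float32)))   # also 6.0
```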

mattjj commented 5 years ago

This seems like an important operation to offer :)

I think the binary integer pow function isn't implemented in XLA, since we seem to be getting this error and the kPow opcode doesn't seem to appear in that list. Maybe just an oversight. I'll follow up with XLA folks.

mattjj commented 5 years ago

A second issue here is that grad should raise an error on non-floating-point (e.g. integer) argument types.
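For reference, recent JAX releases do reject integer inputs to grad with a TypeError (the exact message varies by version). The helper grad_rejects_integers below is hypothetical and only checks for that behavior.

```python
import jax

def square(x):
    return x ** 2

def grad_rejects_integers(fn, value):
    """Return True if jax.grad(fn) raises TypeError when called on value."""
    try:
        jax.grad(fn)(value)
    except TypeError:
        return True
    return False

print(grad_rejects_integers(square, 3))    # int input: expected True
print(grad_rejects_integers(square, 3.0))  # float input: expected False
```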

mattjj commented 5 years ago

@hawkinsp guessed that XLA's Pow HLO is meant to model std::pow, which apparently doesn't work on integer values either. We should solve this in JAX at the jax.numpy level.
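One way to provide integer powers at the jax.numpy level without relying on XLA's Pow is exponentiation by squaring, which needs only multiplications. This is a minimal sketch, not necessarily how jax.numpy.power is implemented: the name int_power_by_squaring is hypothetical, and the exponent is assumed to be a static, non-negative Python int, so the loop unrolls at trace time.

```python
import jax
import jax.numpy as jnp

def int_power_by_squaring(x, n):
    """Hypothetical sketch: x**n for an integer array x and a static int n >= 0."""
    if not isinstance(n, int) or n < 0:
        raise ValueError("exponent must be a non-negative Python int")
    result = jnp.ones_like(x)  # x**0 == 1, same integer dtype as x
    base = x
    while True:
        if n & 1:              # this bit of n is set: multiply in the current power
            result = result * base
        n >>= 1
        if n == 0:
            break
        base = base * base     # advance base to the next power of two: x, x**2, x**4, ...
    return result

# The exponent must stay a Python int, so mark it static under jit.
jitted_pow = jax.jit(int_power_by_squaring, static_argnums=1)
print(jitted_pow(jnp.array([2, 3, -4]), 3))
```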

hawkinsp commented 5 years ago

Thanks for the issue report!

I added support for integer powers to jax.numpy.power.
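A short usage sketch of jax.numpy.power with integer inputs, assuming a JAX version that includes this change. Both a scalar integer exponent and an elementwise integer exponent array are shown; the result keeps an integer dtype.

```python
import jax.numpy as jnp

x = jnp.array([2, 3, 4])                      # integer array
squared = jnp.power(x, 2)                     # integer base, scalar integer exponent
mixed = jnp.power(x, jnp.array([3, 2, 1]))    # elementwise integer exponents

print(squared, squared.dtype)
print(mixed, mixed.dtype)
```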

I also filed https://github.com/google/jax/issues/424 for raising an error if taking grad of an integer.