iandanforth closed this issue 5 years ago.
This seems like an important operation to offer :)
I think the binary integer pow function isn't implemented in XLA: we seem to be getting this error, and the `kPow` opcode doesn't appear in that list. Maybe it's just an oversight. I'll follow up with the XLA folks.
A second issue here is that `grad` should raise an error on non-floating-point argument types.
@hawkinsp guessed that XLA's Pow HLO is meant to model `std::pow`, which apparently doesn't work on integer values either. We should solve this in JAX at the `jax.numpy` level.
Thanks for the issue report!
I added support for integer powers to `jax.numpy.power`.
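Without an integer Pow op in the backend, one standard way to provide integer powers at the library level is exponentiation by squaring, which needs only multiplications. The sketch below is plain Python on scalars; whether `jax.numpy.power` uses exactly this scheme is an assumption, and the name `int_pow` is hypothetical.

```python
def int_pow(base: int, exponent: int) -> int:
    """Compute base**exponent for a non-negative integer exponent.

    Exponentiation by squaring: walk the bits of the exponent from low to
    high, squaring the base at each step and multiplying it into the result
    whenever the current bit is 1. Uses O(log exponent) multiplications.
    """
    if exponent < 0:
        # Negative exponents generally have no integer-valued result.
        raise ValueError("exponent must be non-negative")
    result = 1
    while exponent > 0:
        if exponent & 1:      # low bit set: include this power of the base
            result *= base
        base *= base          # base, base**2, base**4, ...
        exponent >>= 1        # move to the next bit
    return result


print(int_pow(2, 10))  # 1024
```

Because the loop is built from multiplication alone, the same idea can be expressed with ordinary elementwise multiply ops, which a backend supports for integers even when it lacks a dedicated integer pow.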
I also filed https://github.com/google/jax/issues/424 to track raising an error when taking `grad` with respect to an integer argument.
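To make the proposed check concrete, here is a self-contained toy in plain Python, not JAX's implementation: a forward-mode derivative built on dual numbers, where the hypothetical `toy_grad` rejects integer arguments up front. The names `Dual` and `toy_grad` are illustrative only, and `Dual` supports just multiplication and non-negative integer powers.

```python
import numbers


class Dual:
    """Dual number val + tangent*eps with eps**2 == 0 (forward-mode AD)."""

    def __init__(self, val, tangent):
        self.val = val
        self.tangent = tangent

    def __mul__(self, other):
        if not isinstance(other, Dual):
            other = Dual(other, 0.0)  # constants have zero tangent
        # Product rule: (a + a' eps)(b + b' eps) = ab + (a'b + ab') eps
        return Dual(self.val * other.val,
                    self.tangent * other.val + self.val * other.tangent)

    __rmul__ = __mul__

    def __pow__(self, n):
        # Toy restriction: non-negative integer exponents only.
        if not isinstance(n, numbers.Integral) or n < 0:
            raise NotImplementedError("toy Dual supports only n >= 0 integer powers")
        result = Dual(1.0, 0.0)
        for _ in range(n):
            result = result * self
        return result


def toy_grad(f):
    """Hypothetical grad that refuses non-floating-point inputs."""
    def df(x):
        # bool and int (including numpy integers) are numbers.Integral
        if isinstance(x, numbers.Integral):
            raise TypeError(
                f"toy_grad requires a floating-point argument, got {type(x).__name__}")
        return f(Dual(float(x), 1.0)).tangent
    return df


def cube(x):
    return x ** 3


print(toy_grad(cube)(3.0))  # 27.0
# toy_grad(cube)(3) raises TypeError instead of differentiating an int
```

Checking the dtype at the entry point surfaces the mistake at the call site, rather than as a confusing failure deeper in the computation.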
I was surprised this threw an error. Changing it to `val = 3.0` works as expected.