Open · jessegrabowski opened this issue 7 months ago
I took out the bug label since it's more like a missing feature.
Is `ProdWithoutZeros` just `pt.prod(x[pt.neq(x, 0)])`? In that case I would suggest dispatching to something like that in JAX. `jax.numpy.prod` even accepts a `where` argument already: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.prod.html

Although I am not sure it's going to like that dynamic boolean array. Maybe it needs to be implemented as a Scan?
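As a sketch of that `where`-based idea (using NumPy here, whose `np.prod` takes the same `where` argument as `jax.numpy.prod`; this is an illustration, not the actual dispatch):

```python
import numpy as np

x = np.array([2.0, 0.0, 3.0, 0.0, 4.0])

# Product over the non-zero entries only; masked-out positions
# contribute the multiplicative identity (1), so no dynamic-shape
# boolean indexing is needed.
result = np.prod(x, where=(x != 0))

print(result)  # 24.0, i.e. 2 * 3 * 4
```

Because the mask is a static-shape boolean array rather than a boolean index, this form should avoid the dynamic-shape problem mentioned above.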
The graph JAX produces for the grad of `prod` is absurd. They just enumerate all the cases where a zero might be (with some bisect logic):
```python
import jax
import jax.numpy as jnp

def prod(x):
    return jnp.prod(x)

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))
```
Maybe rewrite to something like this?
```python
import jax
import jax.numpy as jnp

def prod(x):
    return jnp.exp(jnp.sum(jnp.log(x)))

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))
```
Amusingly, this was suggested as a workaround here before the actual thing we have now was implemented.
That would fail with negative inputs, no?
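Right. One way around the negative-input problem (a sketch, not what either library does) is to carry the sign separately and take logs of absolute values; zeros would still need separate handling since `log(0)` is `-inf`:

```python
import numpy as np

def prod_via_logs(x):
    # Track the sign separately, since log is only defined for
    # positive values; log|x| handles the magnitudes.
    sign = np.prod(np.sign(x))
    return sign * np.exp(np.sum(np.log(np.abs(x))))

x = np.array([-2.0, 3.0, -4.0])
print(prod_via_logs(x))  # ≈ 24.0, matching np.prod(x)
```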
Would `prod(eq(x, 0), 1, x)` work?
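If the intent here is to replace zeros with the multiplicative identity before taking an ordinary product (my reading of the suggestion, sketched in NumPy where `np.where` plays the role of a `switch`):

```python
import numpy as np

x = np.array([2.0, 0.0, 3.0])

# Swap zeros for 1 so they don't affect the product,
# then take a plain product over everything.
result = np.prod(np.where(x == 0, 1.0, x))

print(result)  # 6.0
```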
Isn't this exactly the JAX graph but without the bisect logic to avoid checking every single value in `x` elemwise?
I don't see anywhere where JAX is checking for zeros.
The PR that implemented the grad is here: https://github.com/google/jax/pull/675/files
It's a bit odd; is that really better than just a switch statement?
In any case, you are asking to rewrite a gradient, and unfortunately, because grads are eager, you don't know whether this `MulWithZeros` comes from the grad of a `prod` or from some other source. So to be safe you would need to implement that exact Op dispatch.
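For context on why the Op shows up in gradients at all: the gradient of `prod(x)` with respect to `x_i` is the product of all the *other* entries, which is exactly a product-without-zeros when `x_i` is the only zero. A brute-force NumPy check of that identity (an illustration, not PyTensor's implementation):

```python
import numpy as np

def grad_prod(x):
    # d/dx_i prod(x) = product of all entries except x_i.
    return np.array([np.prod(np.delete(x, i)) for i in range(len(x))])

x = np.array([2.0, 0.0, 3.0])
print(grad_prod(x))  # [0. 6. 0.]
# The entry at the zero's position is the product of the non-zero
# entries (6.0); all other entries still contain the zero factor.
```

This is why the naive `prod(x) / x_i` form breaks down at zeros and a dedicated Op (or JAX's case enumeration) is needed.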
Describe the issue:

The `ProdWithoutZeros` `Op` arises in the gradients of `pt.prod`. This currently cannot be compiled in JAX mode unless we specifically pass `no_zeros_in_input=True`. I guess we would just need a JAX dispatch for this function? Or maybe a mapping to the correct `jax.lax` function?

Reproducible code example:
Error message:
PyTensor version information:
PyTensor 2.17.4
Context for the issue:
I want the gradient of a product in JAX mode