Hi - I think the issue here is that gradients are fundamentally a function of real values, where you're trying to take the gradient with respect to an integer. When you cast to `int32`, it effectively creates a step function, which has a zero gradient everywhere.
Using JAX's auto-differentiation machinery for a task like this will likely require converting your discrete function into some kind of smooth approximation.
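For example, one such relaxation replaces the hard index with a soft one-hot mask built from the continuous parameter. This is a minimal sketch; the name `soft_one_hot` and the Gaussian-bump weighting are illustrative choices, not anything from JAX itself:

```python
import jax
import jax.numpy as jnp

def soft_one_hot(index, n, temperature=0.1):
    # Differentiable "bump" centered at the continuous index; as
    # temperature -> 0 this approaches a hard one-hot vector.
    positions = jnp.arange(n)
    return jax.nn.softmax(-((positions - index) ** 2) / temperature)

def loss(index, x, target_value=1.0):
    mask = soft_one_hot(index, x.shape[0])
    # Softly "set" the indexed position to target_value.
    x_soft = x * (1.0 - mask) + target_value * mask
    return jnp.sum((x_soft - jnp.ones_like(x)) ** 2)

print(jax.grad(loss)(3.2, jnp.zeros(8)))  # nonzero gradient w.r.t. index
```

Because every operation stays in floating point, `grad` sees a smooth dependence on `index`; the temperature trades off how sharply the mask concentrates on one position against how informative the gradient is.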
Hi all! I have a problem in the audio domain in which I am attempting to find, via gradient descent, which array index should be set to a certain (static) value. That is, I'm trying to optimize an array index itself, rather than values held within an array.
I have provided a toy example below that demonstrates the type of thing I'd like to do.
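A minimal sketch of this kind of setup (the names `loss`, `index`, and `X` are illustrative):

```python
import jax
import jax.numpy as jnp

def loss(index):
    # Cast the float parameter to an integer in order to index with it.
    i = index.astype(jnp.int32)
    # Set that position of an otherwise-zero array to a static value.
    X = jnp.zeros(8).at[i].set(1.0)
    # Distance from a target of all ones.
    return jnp.sum((X - jnp.ones(8)) ** 2)

print(jax.grad(loss)(3.0))  # 0.0 -- the gradient w.r.t. index vanishes
```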
`grad` seems to not track the relationship between the dynamic index and the array position when computing gradients. I assume this is because information is "lost" when the float parameter (`index` below) is cast to an `int32` to do the indexing.

Can someone please help me understand more clearly:

1) why `grad` is unable to track the impact of `index` on the array `X` directly used in the loss calculation?
2) whether there are any other potential solutions you could think of for using JAX to estimate which array indices should be set to a certain value? (E.g. some clever way of converting a floating point value into an array mask that's differentiable? A custom VJP?)

For context, in my real application, the array-index values themselves are not directly being optimized, but rather are derived from other parameters. I am estimating many parameters in a tree via gradient descent, some of which can change the values of integers indexing into arrays.
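One possible angle on question 2 is the common straight-through trick via `jax.lax.stop_gradient`: the mask's forward value is an exact hard one-hot, while gradients flow through a smooth surrogate. The sketch below assumes this approach; the helper `hard_mask_st` and the Gaussian-softmax surrogate are illustrative, not an established API:

```python
import jax
import jax.numpy as jnp

def hard_mask_st(index, n, temperature=0.1):
    positions = jnp.arange(n)
    # Smooth surrogate mask, differentiable w.r.t. the continuous index.
    soft = jax.nn.softmax(-((positions - index) ** 2) / temperature)
    # Exact hard one-hot at the rounded index.
    hard = jax.nn.one_hot(jnp.round(index).astype(jnp.int32), n)
    # Straight-through: the forward value equals `hard`, but the backward
    # pass sees only `soft`, so the gradient w.r.t. index is nonzero.
    return soft + jax.lax.stop_gradient(hard - soft)

def loss(index, x, target_value=1.0):
    mask = hard_mask_st(index, x.shape[0])
    x_new = x * (1.0 - mask) + target_value * mask  # exact set in forward
    return jnp.sum((x_new - jnp.ones_like(x)) ** 2)

print(jax.grad(loss)(3.2, jnp.zeros(8)))  # nonzero despite the hard forward
```

The same effect could be packaged behind a `jax.custom_vjp` rule; the `stop_gradient` form is just the shortest way to write it.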
Thanks for any help here!