jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Is it possible to calculate gradients through a parameterized `index_update`? #5245

Closed (khiner closed this issue 3 years ago)

khiner commented 3 years ago

Hi all! I have a problem in the audio domain in which I am attempting to find, via gradient descent, which array index should be set to a certain (static) value. That is, I'm trying to optimize an array index itself, rather than values held within an array.

I have provided a toy example below that demonstrates the kind of thing I'd like to do. `grad` does not seem to track the relationship between the dynamic index and the array position when computing gradients. I assume this is because information is "lost" when the float parameter (`index` below) is cast to an int32 to do the indexing.

Can someone please help me understand more clearly: 1) why `grad` is unable to track the impact of `index` on the array `X` used directly in the loss calculation, and 2) whether there are any other potential solutions you could think of for using jax to estimate which array indices should be set to a certain value? (E.g. some clever way of converting a floating-point value into an array mask that is differentiable? A custom VJP?)

For context, in my real application, the array-index values themselves are not directly being optimized, but rather are derived from other parameters. I am estimating many parameters in a tree via gradient descent, some of which can change the values of integers indexing into arrays.

Thanks for any help here!

import jax.numpy as jnp
from jax import value_and_grad, jit
from jax.ops import index_update

def parameterized_index_update(index, X):
    Y = jnp.arange(5)
    # This works (finds gradient):
    # X = jnp.array([0, 1, index, 3, 4])
    # This doesn't work (gradient is always zero):
    X = index_update(X, index.astype('int32'), 2.0)
    return ((Y - X) ** 2).mean()

X = jnp.arange(5)
grad_fn = jit(value_and_grad(parameterized_index_update))
# Optimization goal: find the index to set to 2 (Answer: index == 2)
index = 4.0
for train_i in range(5):
    loss, grad = grad_fn(index, X)
    print('Loss: {:.2f}, Grad: {:.2f}'.format(loss, grad))
    index -= grad

Output:

Loss: 0.80, Grad: 0.00
Loss: 0.80, Grad: 0.00
Loss: 0.80, Grad: 0.00
Loss: 0.80, Grad: 0.00
Loss: 0.80, Grad: 0.00
jakevdp commented 3 years ago

Hi - I think the issue here is that gradients are fundamentally defined for real-valued inputs, whereas you're trying to take the gradient with respect to an integer. When you cast to int32, you effectively create a step function, which has a zero gradient everywhere it is defined.
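
A quick illustration of that point (my own example, not from this thread): jnp.floor is piecewise constant, so its derivative is zero everywhere it is defined, and grad reports exactly that. The float-to-int32 cast used for indexing behaves the same way.

import jax
import jax.numpy as jnp

# floor(x) is flat between integers, so its derivative is 0.0 wherever it
# exists; a float -> int32 cast used as an index is likewise a step function.
print(jax.grad(jnp.floor)(1.7))  # prints 0.0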

Using JAX's auto-differentiation machinery for a task like this will likely require converting your discrete function into some kind of smooth approximation.
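
For concreteness, here is a minimal sketch of one such smoothing (my own construction, not an API from this thread; the soft_index_update helper, the squared-distance mask, and the temperature tau are all assumptions): replace the hard integer index with a soft one-hot weighting peaked at index, and blend the target value into the array with it. The result is differentiable in index, so grad is nonzero.

import jax
import jax.numpy as jnp
from jax import value_and_grad, jit

def soft_index_update(X, index, value, tau=1.0):
    # Soft one-hot mask peaked at `index`: a softmax over negative squared
    # distances to each position. Smaller tau gives a sharper, more
    # index_update-like mask, but also smaller gradients w.r.t. `index`.
    positions = jnp.arange(X.shape[0])
    weights = jax.nn.softmax(-((positions - index) ** 2) / tau)
    # Blend `value` into X according to the mask; differentiable in `index`.
    return X * (1.0 - weights) + value * weights

def soft_loss(index, X):
    Y = jnp.arange(5)
    X_soft = soft_index_update(X.astype(jnp.float32), index, 2.0)
    return ((Y - X_soft) ** 2).mean()

X = jnp.arange(5)
grad_fn = jit(value_and_grad(soft_loss))
index = 4.0
for _ in range(10):
    loss, g = grad_fn(index, X)
    index -= 0.5 * g  # gradient is nonzero in `index`, so it can move toward 2

If you need the exact hard update on the forward pass, a straight-through-style trick (do the hard index_update in the forward computation but route gradients through a smooth surrogate like the one above, e.g. via jax.lax.stop_gradient or a custom VJP) is another option people commonly reach for.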