jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

How to handle boolean masks #4212

Open AdrienCorenflos opened 4 years ago

AdrienCorenflos commented 4 years ago

Hi,

I know there are already a bunch of issues open related to this, but I'm still fairly confused about what to do. I have a fairly standard problem, which can be formulated as follows:

p(x) = N(x; Ax, C_x)
p(y|x) = N(y; Hx, C_y)

I'm trying to compute the posterior p(x|y). This involves a bunch of standard linear algebra operations, but say I only need to do the following:

valid = ~jnp.isnan(y)
cov_valid = valid[:, None] * valid[None, :]

n_valid = jnp.sum(valid.astype(jnp.int32))

# First select the valid data only
y_valid = y[valid]
H_valid = H[valid, :]
C_y_valid = C[cov_valid].reshape((n_valid, n_valid))

# Then compute some stuff
y_pred = jnp.dot(H_valid, x)
cov_pred = jnp.dot(H_valid, jnp.dot(C_x, H_valid.T)) + C_y_valid

The above fails because the shapes of y_valid etc. are not known at compile time. However, I can't do without JIT, as I'm using this from within a scan operation.
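For what it's worth, the failure can be reproduced in isolation (a minimal sketch of my own; the exact exception class may vary by JAX version):

```python
import jax
import jax.numpy as jnp

# Boolean indexing produces an output whose shape depends on the data,
# which jit cannot trace, so this fails at trace time.
@jax.jit
def f(y):
    return y[~jnp.isnan(y)]

try:
    f(jnp.array([1.0, jnp.nan, 3.0]))
except Exception as e:
    print("jit rejected the boolean index:", type(e).__name__)
```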

I see there is an (obscure to me) function mask that sounds like it would do the job. Am I right? If so, how can I apply it here?

Thanks,

Adrien

AdrienCorenflos commented 4 years ago

For reference, all the operations can be found here:

https://en.wikipedia.org/wiki/Kalman_filter#Update
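For readers following along, the dense (unmasked) update step from that link can be sketched in JAX as follows. This is my own sketch, not code from the thread; kalman_update and its argument names are hypothetical, with x, C_x the prior mean/covariance, y the observation, H the observation matrix, and C_y the observation noise covariance:

```python
import jax.numpy as jnp

def kalman_update(x, C_x, y, H, C_y):
    # Standard Kalman update (see the Wikipedia link above).
    v = y - H @ x                        # innovation
    S = H @ C_x @ H.T + C_y              # innovation covariance
    K = jnp.linalg.solve(S, H @ C_x).T   # Kalman gain K = C_x H^T S^{-1}
    x_post = x + K @ v                   # posterior mean
    C_post = C_x - K @ S @ K.T           # posterior covariance
    return x_post, C_post
```

The masking question in this thread is about how to run this same update when some entries of y are NaN, without the shapes of y, H, and C_y changing from step to step.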

jakevdp commented 4 years ago

I think the easiest way to make this compile would be to use the three-argument jnp.where function to replace invalid entries with zeros. It might look like this:

cov_valid = jnp.logical_and(valid[:, None], valid[None, :])
n_valid = valid.sum()

# First select the valid data only
y_valid = jnp.where(valid, y, 0)
H_valid = jnp.where(valid[:, None], H, 0)
C_y_valid = jnp.where(cov_valid, C, 0)

# Then compute some stuff
y_pred = jnp.dot(H_valid, x)
cov_pred = jnp.dot(H_valid, jnp.dot(C_x, H_valid.T)) + C_y_valid

Since everything you are doing is a linear product, zeroing-out the rows should be equivalent to masking them.
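A quick check of that equivalence (my own sketch; the arrays here are made up for illustration): zeroing out masked rows with jnp.where keeps shapes static while agreeing with boolean slicing on the valid entries.

```python
import jax.numpy as jnp

y = jnp.array([1.0, jnp.nan, 3.0])
H = jnp.arange(9.0).reshape(3, 3)
x = jnp.ones(3)

valid = ~jnp.isnan(y)
H_masked = jnp.where(valid[:, None], H, 0)  # static shape (3, 3), jit-friendly
H_sliced = H[valid]                         # dynamic shape (2, 3), not jittable

# Zeroed rows of H produce zeroed entries of the product...
assert jnp.allclose(H_masked @ x, jnp.where(valid, H @ x, 0))
# ...and agree with the sliced version on the valid entries.
assert jnp.allclose((H_masked @ x)[valid], H_sliced @ x)
```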

AdrienCorenflos commented 4 years ago

Hi,

Thanks for that. Sadly, I'm not actually doing a simple linear product. The next step is essentially to compute the inverse of this cov_pred (something like linalg.solve(cov_pred, H_valid @ some_matrix, sym_pos=True)); the whole calculation is in the wiki link I added.

So the above wouldn't do (cov_pred would be singular).
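One workaround sometimes used for this (an assumption on my part, not an official JAX API): keep the zeroing approach, but place the identity on the masked block of the covariance so the solve stays non-singular. With the right-hand-side rows also zeroed, the masked components contribute exactly zero to the result. A sketch, where masked_solve_sketch is a hypothetical helper:

```python
import jax.numpy as jnp

def masked_solve_sketch(S, B, valid):
    # Entries where either the row or column is invalid get replaced by the
    # identity, so S_safe is block-diagonal: [valid block, I].
    cov_valid = valid[:, None] & valid[None, :]
    S_safe = jnp.where(cov_valid, S, jnp.eye(S.shape[0]))
    # Invalid rows of the right-hand side are zeroed, so the corresponding
    # rows of the solution are exactly zero.
    B_safe = jnp.where(valid[:, None], B, 0)
    return jnp.linalg.solve(S_safe, B_safe)
```

On the valid rows this agrees with solving the sliced-down system, while all shapes stay static and jit-compatible.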

AdrienCorenflos commented 4 years ago

Is there a plan/target for supporting masked arrays or something similar to TensorFlow's wildcard-style shapes, which allow shapes to vary at runtime? I don't believe I can do without it, and I'd need to know whether I have to change backends before I write more logic.

AdrienCorenflos commented 4 years ago

@jakevdp do you know, then, whether this will be supported at some point in the foreseeable future? I've already rolled up my sleeves and made my code backend-independent (jax/tensorflow/numpy), but I'm still mostly interested in using JAX.

andreped commented 1 year ago

@AdrienCorenflos did you get any further? This limitation is a real showstopper for moving my backend to JAX from the likes of TF/PyTorch.