AdrienCorenflos opened this issue 4 years ago.
For reference, all the operations can be found here:
I think the easiest way to make this compile would be to use the three-argument jnp.where
function to replace invalid entries with zeros. It might look like this:
cov_valid = jnp.logical_and(valid[:, None], valid[None, :])
n_valid = valid.sum()
# First zero out the invalid entries (shapes stay fixed)
y_valid = jnp.where(valid, y, 0)
H_valid = jnp.where(valid[:, None], H, 0)
C_y_valid = jnp.where(cov_valid, C, 0)
# Then compute the predicted observation mean and covariance
y_pred = jnp.dot(H_valid, x)
cov_pred = jnp.dot(H_valid, jnp.dot(C_x, H_valid.T)) + C_y_valid
Since everything you are doing is a linear product, zeroing-out the rows should be equivalent to masking them.
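A quick way to check that claim is a small sketch with made-up random inputs. It uses `C_y` for the observation covariance that the snippet calls `C`, and `masked_predict` is a hypothetical name. The function runs under `jax.jit` with fixed shapes, and its valid entries are compared against an eager NumPy version that slices with boolean indexing (dynamic shapes, no JIT).

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def masked_predict(x, C_x, H, C_y, valid):
    # valid: boolean vector of length m; every array keeps a static shape
    cov_valid = jnp.logical_and(valid[:, None], valid[None, :])
    H_valid = jnp.where(valid[:, None], H, 0.0)
    C_y_valid = jnp.where(cov_valid, C_y, 0.0)
    y_pred = H_valid @ x
    cov_pred = H_valid @ C_x @ H_valid.T + C_y_valid
    return y_pred, cov_pred


rng = np.random.default_rng(0)
n, m = 3, 5
x = rng.normal(size=n)
G = rng.normal(size=(n, n))
C_x = G @ G.T + n * np.eye(n)          # random SPD state covariance
B = rng.normal(size=(m, m))
C_y = B @ B.T + m * np.eye(m)          # random SPD observation covariance
H = rng.normal(size=(m, n))
valid = np.array([True, False, True, True, False])

y_pred, cov_pred = masked_predict(x, C_x, H, C_y, valid)

# Reference: keep only the valid rows eagerly (dynamic shape, no jit)
H_s = H[valid]
ref_mean = H_s @ x
ref_cov = H_s @ C_x @ H_s.T + C_y[np.ix_(valid, valid)]
```

The valid block matches the sliced computation, while the invalid rows and columns of `cov_pred` are exactly zero, so any later inverse of `cov_pred` needs extra care.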
Hi,
Thanks for that. Sadly, I am not actually doing a simple linear product.
The next step is to compute the inverse of this cov_pred
(something like linalg.solve(cov_pred, H_valid @ some_matrix, sym_pos=True));
the whole calculation is in the wiki link I added.
So the above wouldn't work: the zeroed-out rows and columns make cov_pred
singular.
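One common workaround for the singularity, not proposed in this thread, is to put ones on the diagonal of the invalid rows. The padded matrix is then (up to a permutation) block-diagonal, with the valid block of `cov_pred` and an identity block, so it stays positive definite. The solution on the valid rows is unchanged, and it is zero on the invalid rows as long as the right-hand side is zero there (which holds for `H_valid @ some_matrix`). This is a minimal sketch assuming the valid block is positive definite. `masked_solve` is a hypothetical name, and the Cholesky route (`jax.scipy.linalg.cho_factor` / `cho_solve`) stands in for `solve(..., sym_pos=True)`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import cho_factor, cho_solve


@jax.jit
def masked_solve(cov_pred, rhs, valid):
    # cov_pred: (m, m), zero rows/cols where valid is False
    # rhs: (m, k), zero rows where valid is False
    cov_valid = jnp.logical_and(valid[:, None], valid[None, :])
    eye = jnp.eye(cov_pred.shape[0], dtype=cov_pred.dtype)
    # 1 on the diagonal of invalid rows, 0 elsewhere in those rows/cols
    cov_padded = jnp.where(cov_valid, cov_pred, eye)
    factor = cho_factor(cov_padded, lower=True)
    return cho_solve(factor, rhs)


rng = np.random.default_rng(1)
m, k = 4, 2
B = rng.normal(size=(m, m))
S = B @ B.T + m * np.eye(m)                      # full SPD covariance
valid = np.array([True, False, True, True])
cov_pred = np.where(np.outer(valid, valid), S, 0.0)   # what zero-masking yields
rhs = np.where(valid[:, None], rng.normal(size=(m, k)), 0.0)

sol = masked_solve(cov_pred, rhs, valid)

# Reference: solve only the valid sub-system with dynamic shapes
ref = np.linalg.solve(S[np.ix_(valid, valid)], rhs[valid])
```

A side benefit of padding with ones: the log-determinant of the padded matrix equals that of the valid block, since each padded entry contributes log(1) = 0. That is convenient if a marginal likelihood is also needed.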
Is there a plan or target for supporting masked arrays, or something similar to TensorFlow's wildcard-style shapes that allow dimensions to vary at runtime? I don't think I can do without it, and I need to know whether to change backends before I write more logic.
@jakevdp, do you know whether this will be supported in the foreseeable future? I've already rolled up my sleeves and made my code backend-independent (JAX/TensorFlow/NumPy), but I'm still mostly interested in using JAX.
@AdrienCorenflos, did you get any further? This limitation is a real showstopper for moving my backend from the likes of TF/PyTorch to JAX.
Hi,
I know there are already several open issues related to this, but I am still fairly confused as to what to do. I have a fairly standard problem, which can be formulated as follows:
p(x) = N(x; Ax, C_x)
p(y|x) = N(y; Hx, C_y)
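For reference, conditioning this Gaussian prior on y gives the usual closed-form posterior. These are the standard Gaussian conditioning (Kalman update) formulas, stated from general knowledge rather than taken from the linked wiki page. Here μ denotes the prior mean, written Ax above. S is the matrix called cov_pred in the masked snippet, and it is the one that has to be inverted. With missing observations, y, the rows of H, and the rows and columns of C_y are restricted to the valid entries.

```latex
\begin{aligned}
S &= H C_x H^\top + C_y \\
K &= C_x H^\top S^{-1} \\
\mathbb{E}[x \mid y] &= \mu + K\,(y - H\mu) \\
\operatorname{Cov}[x \mid y] &= C_x - K S K^\top = C_x - K H C_x
\end{aligned}
```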
I'm trying to compute the posterior p(x|y). This involves a bunch of standard linear algebra operations, but say I only need to do the following:
The above fails because the shape of y_valid, etc., is not known at compile time. However, I can't do without JIT, as I'm using this from within a scan operation. I see there is an (obscure to me) function, mask, that sounds like it would do the job. Am I right? If so, how can I apply it here?
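Instead of relying on the experimental `mask` transformation, the fixed-shape approach can run inside `jax.lax.scan` directly. Each step then receives an observation vector of static length m plus a boolean validity mask. The sketch below is a minimal version under several assumptions:

- It interprets p(x) = N(x; Ax, C_x) as a linear-Gaussian transition x_t | x_{t-1} ~ N(A x_{t-1}, C_x).
- Missing entries of y hold NaN placeholders.
- The names `masked_update`, `masked_filter` and `reference_filter` are hypothetical.

It combines zero-masking with identity padding of S and compares the result against an eager NumPy filter that slices out the valid entries.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import cho_factor, cho_solve


def masked_update(mean, cov, y, valid, H, C_y):
    """Condition N(mean, cov) on the entries of y ~ N(H x, C_y) where valid is True."""
    cov_valid = jnp.logical_and(valid[:, None], valid[None, :])
    H_v = jnp.where(valid[:, None], H, 0.0)
    # Missing entries of y may hold any placeholder (here NaN); they become 0
    resid = jnp.where(valid, y, 0.0) - H_v @ mean
    S = H_v @ cov @ H_v.T + jnp.where(cov_valid, C_y, 0.0)
    # Ones on the invalid diagonal keep S positive definite
    S = jnp.where(cov_valid, S, jnp.eye(y.shape[0], dtype=S.dtype))
    factor = cho_factor(S, lower=True)
    # K = cov H^T S^{-1}, computed as (S^{-1} H cov)^T since S and cov are symmetric
    K = cho_solve(factor, H_v @ cov).T
    return mean + K @ resid, cov - K @ H_v @ cov


@jax.jit
def masked_filter(m0, P0, ys, valids, A, C_x, H, C_y):
    def step(carry, inputs):
        mean, cov = carry
        y, valid = inputs
        # Assumed predict step: x_t | x_{t-1} ~ N(A x_{t-1}, C_x)
        mean = A @ mean
        cov = A @ cov @ A.T + C_x
        mean, cov = masked_update(mean, cov, y, valid, H, C_y)
        return (mean, cov), mean

    (_, cov), means = jax.lax.scan(step, (m0, P0), (ys, valids))
    return means, cov


def reference_filter(m0, P0, ys, valids, A, C_x, H, C_y):
    """Eager NumPy version that slices out the valid entries (dynamic shapes)."""
    mean, cov, means = m0, P0, []
    for y, valid in zip(ys, valids):
        mean = A @ mean
        cov = A @ cov @ A.T + C_x
        H_s = H[valid]
        S = H_s @ cov @ H_s.T + C_y[np.ix_(valid, valid)]
        K = np.linalg.solve(S, H_s @ cov).T
        mean = mean + K @ (y[valid] - H_s @ mean)
        cov = cov - K @ H_s @ cov
        means.append(mean)
    return np.stack(means), cov


rng = np.random.default_rng(2)
n, m, T = 2, 3, 4
A = np.array([[1.0, 0.1], [0.0, 1.0]])
C_x = 0.1 * np.eye(n)
H = rng.normal(size=(m, n))
C_y = 0.5 * np.eye(m)
valids = np.array([[True, True, False],
                   [False, True, True],
                   [True, False, False],
                   [True, True, True]])
ys = np.where(valids, rng.normal(size=(T, m)), np.nan)  # NaN marks missing data
m0, P0 = np.zeros(n), np.eye(n)

means, cov = masked_filter(m0, P0, ys, valids, A, C_x, H, C_y)
ref_means, ref_cov = reference_filter(m0, P0, ys, valids, A, C_x, H, C_y)
```

Because every array inside `step` has a static shape, `scan` and `jit` trace it once. The number of valid observations can change from step to step without retracing.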
Thanks,
Adrien