This example looks pretty interesting to me.
I believe the model could be implemented like this?
def model(X, y):
    N, d = X.shape
    pi = numpyro.param('pi', np.ones((d,)) * 0.5, constraint=unit_interval)
    xi = numpyro.sample('xi', dist.Bernoulli(probs=pi), infer={'enumerate': 'parallel'})
    w = numpyro.sample('w', dist.Normal(0, xi))
    return numpyro.sample('obs', dist.Normal(np.dot(X, w)), obs=y)
I haven't figured out how to write the inference part. Is Enum_elbo available in NumPyro now?
Hi @xidulu, I think you can use MCMC for this model. We haven't supported enum for SVI yet. Regarding the model, I think it does not make sense to use xi for the scale of a Normal distribution (e.g. the scale can't take the value 0). Your model looks a bit different from spike and slab regression. I guess there is a typo or something?
Hi @fehiepsi, my knowledge of the spike and slab prior mostly comes from Eq. 11 in https://arxiv.org/pdf/1810.04045.pdf.
I will take a look at the more authentic spike and slab prior in the paper you linked and polish my code.
Thanks @xidulu! Now I can see what you meant. Unfortunately, there is a limitation in NumPyro that Normal(0, xi) can't represent the Delta probability when xi=0. You can get around that issue by using w = xi * Normal(0, 1).
@fehiepsi
def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
        w = numpyro.sample('w', dist.Normal(0., 10.0))
    return numpyro.sample('obs', dist.Normal(np.dot(X, w * xi)), obs=y)
It seems that the code above is incompatible with the enumeration mechanism:
<ipython-input-31-1286a6ff5a73> in model(X, y)
      5     xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
      6     w = numpyro.sample('w', dist.Normal(0., 10.0))
----> 7     return numpyro.sample('obs', dist.Normal(np.dot(X, w * xi)), obs=y)

~/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in dot(a, b, precision)
   2834     return lax.mul(a, b)
   2835   if _max(a_ndim, b_ndim) <= 2:
-> 2836     return lax.dot(a, b, precision=precision)
   2837
   2838   if b_ndim == 1:

~/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py in dot(lhs, rhs, precision)
    584                    precision=precision)
    585   else:
--> 586     raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
    587         lhs.shape, rhs.shape))
    588

TypeError: Incompatible shapes for dot: got (100, 100) and (2, 100).
I believe the reason is that xi.shape becomes (2, 1) during enumerate_support, causing the shape of xi * Normal(0, 1) to go wrong.
p.s. Code for generating toy data:
import jax
import jax.numpy as np
from jax import random
import numpyro.distributions as dist

def generate_data(key, n_samples=100, n_features=100, features_kept=10):
    X = random.normal(key, (n_samples, n_features))
    # Create weights with a precision lambda_ of 4.
    lambda_ = 4.
    w = np.zeros(n_features)
    # Only keep 10 weights of interest (random.choice expects a shape tuple)
    relevant_features = random.choice(
        key, np.arange(n_features), (features_kept,), replace=False)
    w = jax.ops.index_update(
        w,
        relevant_features,
        dist.Normal(loc=0, scale=1. / np.sqrt(lambda_)).sample(key, (features_kept,))
    )
    # Create noise with a precision alpha of 50.
    alpha_ = 50.
    noise = dist.Normal(loc=0, scale=1. / np.sqrt(alpha_)).sample(key, (n_samples,))
    # Create the target
    y = np.dot(X, w) + noise
    return X, y, w
@xidulu Under enumeration, one restriction is to write your code such that it is compatible with batch dimensions. In particular, the matrix-vector product np.dot(X, w * xi) will not be compatible with a batched `w * xi` (in jax/numpy, a matrix-vector product becomes a matrix-matrix product if the vector has batch dimensions). Usually, for a "batched" matrix-vector product, we use `np.matmul(X, (w * xi)[..., None]).squeeze(-1)`.
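For reference, a minimal shape check (not from the thread; small shapes chosen for clarity) illustrating why the dot product fails while the matmul idiom broadcasts:

import jax.numpy as np

X = np.ones((100, 5))   # (N, d) design matrix
wx = np.ones((2, 5))    # a d-vector with a leading enumeration dimension

# np.dot(X, wx) raises: Incompatible shapes for dot: got (100, 5) and (2, 5).
loc = np.matmul(X, wx[..., None]).squeeze(-1)
print(loc.shape)        # (2, 100): a batched matrix-vector product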
But there is also a bug in our funsor code. I'll make a fix for it soon. After that, your model should run under the enumeration mechanism. Thanks for the report! I am looking forward to seeing your tutorial.
> But there is also a bug in our funsor code. I'll make a fix for it soon. After that, your model should run under the enumeration mechanism.
Unfortunately, enumeration is working as intended - the failure is caused by the fact that @xidulu's model cannot be used with exact enumeration, since the computational complexity of enumerating over the mixture variables is exponential in d. See this section of Pyro's enumeration tutorial for more discussion.
@eb8680
Does it mean that discrete variables (if enumerated) cannot be used inside the plate?
Actually, when I print out the shape of xi, it was (2, 1), which does not contain the shape of the plate.
> Does it mean that discrete variables (if enumerated) cannot be used inside the plate?

No, not exactly - it means that independent discrete variables enumerated inside of a plate cannot be used later outside that plate, because that would mean they are not really independent and we would have to enumerate over all 2^plate_size combinations of their values.
@eb8680
I also tried moving xi outside the plate; however, the shape of xi printed out is still (2, 1):
def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        w = numpyro.sample('w', dist.Normal(0., 10.0))
    xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
    # print(pi.shape)
    # print(xi.shape)
    return numpyro.sample('obs', dist.Normal(np.dot(X, w * xi)), obs=y)
@xidulu The following code might run with #778
def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d, dim=-1):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
        w = numpyro.sample('w', dist.Normal(0., 10.0))
    loc = np.matmul(X, (w * xi)[..., None]).squeeze(-1)
    with numpyro.plate('N', N, dim=-1):
        return numpyro.sample('obs', dist.Normal(loc), obs=y)
but as @eb8680 pointed out, we need a sequential plate for xi. As you observed, xi.shape = (2, 1), which means that xi[0] is all 0s (after broadcasting to become a d-length vector) and xi[1] is all 1s. What we need is a d-length sequence of 0/1 values (there are 2^d such sequences). Sorry for the confusion, I am thinking about a solution...
> but as @eb8680 pointed out, we need a sequential plate for xi.
This would work in the sense that enumeration would produce the correct factors, at least for d < 25 or so, at which point the maximum number of tensor dimensions in JAX is reached. To be clear, though, the underlying problem is mathematical: exactly integrating out the xis in spike-and-slab regression still has cost exponential in d - the observation factor couples all of the discrete variables.
@eb8680
Now I get what you mean: enumerating a d-dimensional Bernoulli vector would inevitably cost O(2^d), right?
> enumerating a d-dimensional Bernoulli vector would inevitably cost O(2^d), right?
Yes, that's right. In our implementation of enumeration in Pyro (and NumPyro) we've chosen to have programs fail and raise an error in this situation rather than attempt to instantiate such a large tensor, although enumeration in NumPyro is still experimental and is missing the more interpretable error messages generated in Pyro.
For more on the restrictions on plate structure in models with enumerated discrete variables, see Pyro's enumeration tutorial, and for background on the math behind discrete variable elimination in Pyro see our ICML 2019 paper Tensor Variable Elimination for Plated Factor Graphs.
@fehiepsi
I just read https://github.com/pyro-ppl/numpyro/issues/779. I guess the current workaround would be to manually declare the plate's dimension (tensor.expand) combined with a for loop, right?
@xidulu There is still a bug that I created during one of the recent PRs. :( I will try to fix it this weekend, then will let you know.
@xidulu Sorry, I forgot to mention that the bug has been fixed. The model is as follows:
def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d, dim=-1):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        w = numpyro.sample('w', dist.Normal(0., 10.0))
    ws = []
    for k in range(d):
        xi_k = numpyro.sample(f'xi_{k}', dist.Bernoulli(probs=pi[k]))
        w_k = np.where(xi_k, w[k:k + 1], 0)
        ws.append(w_k)
    ws = np.concatenate(np.broadcast_arrays(*ws), axis=-1)
    loc = np.einsum("...nd,...d->...n", X, ws)
    with numpyro.plate('N', N, dim=-1):
        return numpyro.sample('obs', dist.Normal(loc), obs=y)
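For reference, a minimal inference sketch for the model above (not from the thread; it assumes funsor is installed, since NumPyro's enumeration relies on it, and uses X, y from generate_data):

from jax import random
from numpyro.infer import MCMC, NUTS

# NUTS marginalizes the finite-support Bernoulli sites xi_k via enumeration.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), X, y)
mcmc.print_summary()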
But using enumeration for this model is pretty slow. :( I am trying to see if this mixed HMC algorithm does a better job here, but it might take a while to have an implementation.
@fehiepsi Do you still think the mixed HMC merged in https://github.com/pyro-ppl/numpyro/pull/826 will fix this? Have you tried it?
Hi @jotsif, both DiscreteHMCGibbs and MixedHMC should work for a spike-and-slab model. Please see the examples in docstring of those classes for how to perform inference. In the above comments, we discussed enumeration, which is a different (working but exponentially slow) approach.
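For example, a minimal sketch of the DiscreteHMCGibbs usage on a spike-and-slab model like the ones above (the class docstrings are the authoritative reference; num_warmup/num_samples here are arbitrary):

from jax import random
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

# Gibbs-update the discrete xi sites, NUTS for the continuous pi and w.
kernel = DiscreteHMCGibbs(NUTS(model))
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), X, y)
mcmc.print_summary()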
FYI, good post with examples of using a horseshoe prior for spike and slab in numpyro here: https://james-brennan.github.io/posts/horseshoe/
It is also nice that the author compares the horseshoe with Bayesian Ridge and Lasso regression models. I tried the same method on one of my own projects and ran into a good number of divergences. I spent some time trying non-centered flavors of the model but didn't have much luck. Unfortunately, the author didn't include the summary outputs from the MCMC, so I am not sure whether they had the same issue without trying to reproduce it.
@d-diaz The horseshoe is another sparsity-inducing prior; compared to the spike and slab, it avoids the enumeration problem discussed in this ticket.
@d-diaz FYI, that post has an error (it should read horseshoe_sigma = Tau * Lambda, not horseshoe_sigma = Tau**2 * Lambda**2). It should probably still produce reasonable results in most cases, but it is not the horseshoe prior (it's something even sparser).
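For concreteness, a minimal horseshoe regression sketch using the corrected scale (a hypothetical model, not the linked post's code; the unit observation noise is an assumption):

import jax.numpy as np
import numpyro
import numpyro.distributions as dist

def horseshoe_model(X, y):
    N, d = X.shape
    tau = numpyro.sample('tau', dist.HalfCauchy(1.0))  # global shrinkage
    with numpyro.plate('d', d):
        lam = numpyro.sample('lam', dist.HalfCauchy(1.0))  # local shrinkage
        # Corrected scale: tau * lam, not tau**2 * lam**2
        w = numpyro.sample('w', dist.Normal(0.0, tau * lam))
    return numpyro.sample('obs', dist.Normal(np.dot(X, w), 1.0), obs=y)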
There's now a horseshoe example here.
Closed because we have supported some MCMC methods like DiscreteHMCGibbs and have examples in the documentation of those methods.
Per a question asked by a user in the forum, it would be nice to have a tutorial/example for this type of regression. I searched for available examples and found some nice ones below.
None of those examples marginalize the discrete latent variables, so this would be a good example for the enumeration mechanism in NumPyro. I would expect the result with marginalization to be better than with the methods above.
For those who are interested in this issue, please ping me with any questions you have.