pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Example for spike-and-slab regression #769

Closed: fehiepsi closed this issue 3 months ago

fehiepsi commented 4 years ago

Per a question asked by a user in the forum, it would be nice to have a tutorial/example for this type of regression. I searched for available examples and found some nice ones.

None of those examples marginalize out the discrete latent variables, so this would be a good example for the enumeration mechanism in NumPyro. I would expect that, with marginalization, the results will be better than those of the methods in those examples.

If you are interested in this issue, please ping me with any questions you have.

xidulu commented 4 years ago

This example looks pretty interesting to me.

I believe the model could be implemented like this?

import jax.numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints

def model(X, y):
    N, d = X.shape
    pi = numpyro.param('pi', np.ones((d,)) * 0.5, constraint=constraints.unit_interval)
    xi = numpyro.sample('xi', dist.Bernoulli(probs=pi), infer={'enumerate': 'parallel'})
    w = numpyro.sample('w', dist.Normal(0, xi))
    return numpyro.sample('obs', dist.Normal(np.dot(X, w)), obs=y)

I haven't figured out how to write the inference part yet. Is an enumeration ELBO (like Pyro's `TraceEnum_ELBO`) available in NumPyro now?

fehiepsi commented 4 years ago

Hi @xidulu, I think you can use MCMC for this model; we don't support enumeration for SVI yet. Regarding the model, I don't think it makes sense to use xi as the scale of the Normal distribution (the scale can't be 0). Your model looks a bit different from spike-and-slab regression. Is there a typo or something?

xidulu commented 4 years ago

Hi @fehiepsi, my knowledge of the spike-and-slab prior mostly comes from Eq. 11 in https://arxiv.org/pdf/1810.04045.pdf


I will take a look at the more authentic spike and slab prior in the paper you linked and polish my code.

fehiepsi commented 4 years ago

Thanks @xidulu! Now I see what you meant. Unfortunately, NumPyro has a limitation: Normal(0, xi) can't represent a point mass (Delta distribution) at 0 when xi = 0. You can work around this by sampling from Normal(0, 1) and multiplying, i.e. w = xi * Normal(0, 1).
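A minimal NumPy sketch of that workaround: sample the indicator and a slab draw separately, then multiply, so the weight is exactly 0 whenever xi is 0. NumPy stands in for `jax.numpy` here and `sample_spike_and_slab` is a hypothetical helper; inside a NumPyro model the same idea would be a `numpyro.sample` from `dist.Normal(0., 1.)` multiplied by `xi`.

```python
import numpy as np

def sample_spike_and_slab(rng, pi, slab_scale, num_samples):
    """Draw w = xi * z with xi ~ Bernoulli(pi) and z ~ Normal(0, slab_scale)."""
    d = len(pi)
    xi = (rng.random((num_samples, d)) < pi).astype(float)   # Bernoulli indicators
    z = rng.normal(0.0, slab_scale, size=(num_samples, d))   # slab draws
    w = xi * z                                                # exactly 0 where xi == 0 (the spike)
    return xi, w

rng = np.random.default_rng(0)
pi = np.array([0.1, 0.5, 0.9])
xi, w = sample_spike_and_slab(rng, pi, slab_scale=1.0, num_samples=10_000)
print(np.mean(w == 0.0, axis=0))  # close to 1 - pi = [0.9, 0.5, 0.1]
```

Multiplying (rather than using xi as a scale) keeps every distribution well defined: the Normal always has a positive scale, and the point mass at 0 comes from the indicator alone.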

xidulu commented 4 years ago

@fehiepsi

import jax.numpy as np
import numpyro
import numpyro.distributions as dist

def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
        w = numpyro.sample('w', dist.Normal(0., 10.0))
    return numpyro.sample('obs', dist.Normal(np.dot(X, w * xi)), obs=y)

It seems that the code above is incompatible with the enumeration mechanism:

<ipython-input-31-1286a6ff5a73> in model(X, y)
      5         xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
      6         w = numpyro.sample('w', dist.Normal(0., 10.0))
----> 7     return numpyro.sample('obs', dist.Normal(np.dot(X, w * xi)), obs=y)

~/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in dot(a, b, precision)
   2834     return lax.mul(a, b)
   2835   if _max(a_ndim, b_ndim) <= 2:
-> 2836     return lax.dot(a, b, precision=precision)
   2837 
   2838   if b_ndim == 1:

~/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py in dot(lhs, rhs, precision)
    584                        precision=precision)
    585   else:
--> 586     raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
    587         lhs.shape, rhs.shape))
    588 

TypeError: Incompatible shapes for dot: got (100, 100) and (2, 100).

I believe the reason is that xi.shape becomes (2, 1) under enumeration (via enumerate_support), which gives xi * Normal(0, 1) the wrong shape.
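The shape mismatch can be reproduced without NumPyro. The sketch below assumes, as printed in this thread, that the enumerated values of xi are laid out with shape (2, 1); broadcasting them against the (d,)-shaped w gives a (2, d) array, which `np.dot` then treats as a matrix. NumPy is used in place of `jax.numpy` (NumPy raises `ValueError` where JAX raised `TypeError`), and `dot_fails` is a hypothetical helper.

```python
import numpy as np

def dot_fails(a, b):
    """Return True if np.dot rejects the two shapes."""
    try:
        np.dot(a, b)
    except ValueError:
        return True
    return False

N, d = 100, 100
X = np.ones((N, d))                  # design matrix, shape (N, d)
w = np.ones(d)                       # slab weights, shape (d,)
xi = np.arange(2).reshape(2, 1)      # enumerated Bernoulli values {0, 1}, shape (2, 1)

b = w * xi                           # broadcasts to (2, d): no longer a vector
print(b.shape)          # (2, 100)
print(dot_fails(X, b))  # True: (100, 100) . (2, 100) is not a valid matrix product
```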


P.S. Code for generating toy data:

import jax.numpy as np
from jax import random
import numpyro.distributions as dist

def generate_data(key, n_samples=100, n_features=100, features_kept=10):
    # Use an independent key for each random draw
    key_x, key_idx, key_w, key_noise = random.split(key, 4)
    X = random.normal(key_x, (n_samples, n_features))
    # Create weights with a precision lambda_ of 4.
    lambda_ = 4.
    w = np.zeros(n_features)
    # Only keep `features_kept` weights of interest
    relevant_features = random.choice(
        key_idx, np.arange(n_features), (features_kept,), replace=False)
    # Functional index update (replaces the removed jax.ops.index_update)
    w = w.at[relevant_features].set(
        dist.Normal(loc=0., scale=1. / np.sqrt(lambda_)).sample(key_w, (features_kept,))
    )
    # Create noise with a precision alpha of 50.
    alpha_ = 50.
    noise = dist.Normal(loc=0., scale=1. / np.sqrt(alpha_)).sample(key_noise, (n_samples,))
    # Create the target
    y = np.dot(X, w) + noise
    return X, y, w
fehiepsi commented 4 years ago

@xidulu Under enumeration, one restriction is that your code must be compatible with batch dimensions. In particular, the matrix-vector product np.dot(X, w * xi) will not be compatible with a batched "w * xi" (in jax/numpy, a matrix-vector product becomes a matrix-matrix product if the vector has batch dimensions). Usually, for a "batched" matrix-vector product, we use `np.matmul(X, (w * xi)[..., None]).squeeze(-1)`.
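A small NumPy check of the batched matrix-vector trick (NumPy mirrors `jax.numpy` here; the random data are illustrative): `np.matmul` treats the last two axes as the matrices and broadcasts any leading ones, so appending a trailing axis of size 1 turns each batched vector into a column. An `einsum` with an ellipsis gives the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 100, 100
X = rng.normal(size=(N, d))
v = rng.normal(size=(2, d))      # a "batched" vector, e.g. w * xi under enumeration

# (N, d) @ (2, d, 1) -> (2, N, 1) -> squeeze -> (2, N)
loc = np.matmul(X, v[..., None]).squeeze(-1)

# Equivalent einsum formulation
loc_einsum = np.einsum("nd,...d->...n", X, v)

# Reference: one ordinary matrix-vector product per batch element
loc_loop = np.stack([X @ v[b] for b in range(v.shape[0])])
print(loc.shape)  # (2, 100)
```

The same expression still works for an unbatched vector of shape (d,), which is why it is a safe default inside models that may be enumerated.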

But there is also a bug in our funsor code. I'll make a fix for it soon. After that, your model should run under the enumeration mechanism. Thanks for the report! I am looking forward to seeing your tutorial.

eb8680 commented 4 years ago

But there is also a bug in our funsor code. I'll make a fix for it soon. After that, your model should run under the enumeration mechanism.

Unfortunately, enumeration is working as intended: the failure occurs because @xidulu's model cannot be used with exact enumeration, since the computational cost of enumerating over the mixture variables is exponential in d. See the relevant section of Pyro's enumeration tutorial for more discussion.

xidulu commented 4 years ago

@eb8680

Does this mean that discrete variables (if enumerated) cannot be used inside a plate?

Actually, when I print the shape of xi, it is (2, 1), which does not include the plate dimension.

eb8680 commented 4 years ago

Does this mean that discrete variables (if enumerated) cannot be used inside a plate?

No, not exactly: it means that independent discrete variables enumerated inside a plate cannot be used later outside that plate, because then they would not really be independent, and we would have to enumerate over all 2^plate_size combinations of their values.

xidulu commented 4 years ago

@eb8680 I also tried moving xi outside the plate; however, the printed shape of xi is still (2, 1):

import jax.numpy as np
import numpyro
import numpyro.distributions as dist

def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        w = numpyro.sample('w', dist.Normal(0., 10.0))
    xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
    # print(pi.shape)
    # print(xi.shape)
    return numpyro.sample('obs', dist.Normal(np.dot(X, w * xi)), obs=y)
fehiepsi commented 4 years ago

@xidulu The following code might run with #778

import jax.numpy as np
import numpyro
import numpyro.distributions as dist

def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d, dim=-1):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        xi = numpyro.sample('xi', dist.Bernoulli(probs=pi))
        w = numpyro.sample('w', dist.Normal(0., 10.0))
    # Batched matrix-vector product, compatible with enumeration dims
    loc = np.matmul(X, (w * xi)[..., None]).squeeze(-1)
    with numpyro.plate('N', N, dim=-1):
        return numpyro.sample('obs', dist.Normal(loc), obs=y)

But as @eb8680 pointed out, we need a sequential plate for xi. As you observed, xi.shape = (2, 1), which means that xi[0] is all 0s (after broadcasting to a d-length vector) and xi[1] is all 1s. What we need instead is every d-length sequence of 0/1 values (there are 2^d such sequences). Sorry for the confusion; I am thinking about a solution...
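To make the shape argument concrete, here is a NumPy sketch contrasting what a (2, 1) enumeration covers with what exact marginalization over d indicators needs; `d = 3` is an illustrative choice.

```python
import itertools
import numpy as np

d = 3
xi_enum = np.arange(2).reshape(2, 1)       # enumerated values as printed: shape (2, 1)
rows = np.broadcast_to(xi_enum, (2, d))    # what each enumerated value means for all d entries
print(rows)   # [[0 0 0]
              #  [1 1 1]]  -> only 2 joint configurations

all_configs = np.array(list(itertools.product([0, 1], repeat=d)))
print(all_configs.shape)   # (8, 3): the 2**d distinct 0/1 sequences needed
```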

eb8680 commented 4 years ago

But as @eb8680 pointed out, we need a sequential plate for xi.

This would work in the sense that enumeration would produce the correct factors, at least for d < 25 or so, beyond which the maximum number of tensor dimensions in JAX is reached. To be clear, though, the underlying problem is mathematical: exactly integrating out the xis in spike-and-slab regression still has a cost exponential in d, because the observation factor couples all of the discrete variables.
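A brute-force NumPy sketch of why the cost is exponential: with w and pi held fixed, summing the indicators out of the likelihood means visiting every one of the 2^d configurations, because y depends on all of them jointly through X (w * xi). The unit noise scale mirrors the models in this thread; `marginal_loglik` and `log_normal` are hypothetical helpers, and this is plain enumeration, not NumPyro's variable elimination.

```python
import itertools
import numpy as np

def log_normal(y, loc, scale=1.0):
    """Elementwise log density of Normal(loc, scale) at y."""
    return -0.5 * ((y - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def marginal_loglik(X, y, w, pi):
    """log p(y | X, w, pi) with the indicators xi summed out exactly.

    Returns the value and the number of configurations visited (2**d).
    """
    d = X.shape[1]
    terms = []
    for config in itertools.product([0, 1], repeat=d):
        xi = np.array(config, dtype=float)
        # log p(xi | pi): independent Bernoulli prior on each indicator
        log_prior = np.sum(xi * np.log(pi) + (1 - xi) * np.log1p(-pi))
        # log p(y | X, w * xi): couples all indicators through the linear predictor
        log_lik = np.sum(log_normal(y, X @ (w * xi)))
        terms.append(log_prior + log_lik)
    terms = np.array(terms)
    m = terms.max()  # log-sum-exp for numerical stability
    return m + np.log(np.sum(np.exp(terms - m))), len(terms)

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 4))
w = rng.normal(size=4)
pi = np.full(4, 0.5)
y = X @ (w * np.array([1.0, 0.0, 1.0, 0.0])) + rng.normal(size=20)
logp, num_terms = marginal_loglik(X, y, w, pi)
print(num_terms)  # 16 == 2**4; doubles with every extra feature
```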

xidulu commented 4 years ago

@eb8680

Now I get what you mean: enumerating a d-dimensional Bernoulli vector would inevitably cost O(2^d), right?

eb8680 commented 4 years ago

enumerating a d-dimensional Bernoulli vector would inevitably cost O(2^d), right?

Yes, that's right. In our implementation of enumeration in Pyro (and NumPyro) we've chosen to have programs fail and raise an error in this situation rather than attempt to instantiate such a large tensor, although enumeration in NumPyro is still experimental and is missing the more interpretable error messages generated in Pyro.

For more on the restrictions on plate structure in models with enumerated discrete variables, see Pyro's enumeration tutorial, and for background on the math behind discrete variable elimination in Pyro see our ICML 2019 paper Tensor Variable Elimination for Plated Factor Graphs.

xidulu commented 4 years ago

@fehiepsi

I just read https://github.com/pyro-ppl/numpyro/issues/779. I guess the current workaround would be to manually declare the plate's dimension (tensor.expand) combined with a for loop, right?

fehiepsi commented 4 years ago

@xidulu There is still a bug that I introduced in one of the recent PRs. :( I will try to fix it this weekend and then let you know.

fehiepsi commented 4 years ago

@xidulu Sorry, I forgot to mention that the bug has been fixed. The model is as follows:

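import jax.numpy as np
import numpyro
import numpyro.distributions as dist

def model(X, y):
    N, d = X.shape
    with numpyro.plate('d', d, dim=-1):
        pi = numpyro.sample('pi', dist.Beta(0.5, 0.5))
        w = numpyro.sample('w', dist.Normal(0., 10.0))

    # One scalar indicator per feature, so each xi_k gets its own
    # enumeration dimension (a "sequential plate" written as a loop)
    ws = []
    for k in range(d):
        xi_k = numpyro.sample(f'xi_{k}', dist.Bernoulli(probs=pi[k]))
        w_k = np.where(xi_k, w[k:k + 1], 0)  # spike at 0 when xi_k == 0
        ws.append(w_k)
    # Broadcast the per-feature weights together, then stack along the feature axis
    ws = np.concatenate(np.broadcast_arrays(*ws), axis=-1)

    # Batched matrix-vector product: works with leading enumeration dims
    loc = np.einsum("...nd,...d->...n", X, ws)
    with numpyro.plate('N', N, dim=-1):
        return numpyro.sample('obs', dist.Normal(loc), obs=y)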

But using enumeration for this model is pretty slow. :( I am trying to see whether this mixed HMC algorithm does a better job here, but it might take a while to implement.
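To see where the slowdown comes from, the loop in the model above can be mocked up in NumPy. This assumes a single plate dimension at -1, so that the k-th enumerated indicator occupies dim -(k + 2), which is roughly how parallel enumeration places its dimensions to the left of the plates; the weights and design matrix are made up. After broadcasting, `ws` holds every one of the 2^d on/off patterns.

```python
import numpy as np

d, N = 3, 4
w = np.array([1.5, -2.0, 0.7])          # slab weights, shape (d,)
X = np.ones((N, d))                      # hypothetical design matrix

ws = []
for k in range(d):
    # Mimic parallel enumeration: the two values of xi_k live in their own
    # leftmost axis at dim -(k + 2), assuming one plate dim at -1.
    xi_k = np.arange(2).reshape((2,) + (1,) * (k + 1))
    w_k = np.where(xi_k, w[k:k + 1], 0.0)   # shape (2, 1, ..., 1)
    ws.append(w_k)
ws = np.concatenate(np.broadcast_arrays(*ws), axis=-1)
loc = np.einsum("...nd,...d->...n", X, ws)

print(ws.shape)   # (2, 2, 2, 3): 2**d configurations times d weights
print(loc.shape)  # (2, 2, 2, 4)
```

Axis 0 of `ws` indexes xi_2, axis 1 indexes xi_1 and axis 2 indexes xi_0, so `ws[1, 0, 1]` is the weight vector with features 0 and 2 switched on. The tensor size grows as 2^d, which is the exponential cost discussed above.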

jotsif commented 3 years ago

@fehiepsi Do you still think the mixed HMC merged in https://github.com/pyro-ppl/numpyro/pull/826 will fix this? Have you tried it?

fehiepsi commented 3 years ago

Hi @jotsif, both DiscreteHMCGibbs and MixedHMC should work for a spike-and-slab model. Please see the examples in docstring of those classes for how to perform inference. In the above comments, we discussed enumeration, which is a different (working but exponentially slow) approach.

d-diaz commented 3 years ago

FYI, there's a good post with examples of using a horseshoe prior for spike and slab in NumPyro here: https://james-brennan.github.io/posts/horseshoe/

It's also nice that the author compares the horseshoe with Bayesian Ridge and Lasso regression models. I tried the same method on one of my own projects and ran into a good number of divergences. I worked for a while on non-centered variants of the model but didn't have much luck. Unfortunately, the author didn't include the MCMC summary outputs, so I can't tell whether they had the same issue without reproducing it.

jotsif commented 3 years ago

@d-diaz The horseshoe is a different sparsity-inducing prior from the spike-and-slab; it avoids the enumeration problem discussed in this ticket.

martinjankowiak commented 3 years ago

@d-diaz FYI, that post has an error: it should read `horseshoe_sigma = Tau * Lambda`, not `horseshoe_sigma = Tau**2 * Lambda**2`. This should probably still produce reasonable results in most cases, but it is not the horseshoe prior (it's something even sparser).

There's now a horseshoe example here.
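A NumPy sketch of the corrected scale: in the horseshoe prior each weight is Normal with standard deviation tau * lambda_k, where the global tau and the local lambda_k are half-Cauchy (so tau**2 * lambda_k**2 is the variance, not the scale). The helper name and the unit global scale are assumptions for illustration.

```python
import numpy as np

def sample_horseshoe_weights(rng, d, num_samples, global_scale=1.0):
    """Draw w_k ~ Normal(0, tau * lambda_k) with half-Cauchy tau and lambda_k."""
    tau = global_scale * np.abs(rng.standard_cauchy(size=(num_samples, 1)))  # global shrinkage
    lam = np.abs(rng.standard_cauchy(size=(num_samples, d)))                 # local shrinkage
    horseshoe_sigma = tau * lam       # standard deviation: tau * lambda, not tau**2 * lambda**2
    w = rng.normal(size=(num_samples, d)) * horseshoe_sigma
    return tau, lam, horseshoe_sigma, w

rng = np.random.default_rng(0)
tau, lam, sigma, w = sample_horseshoe_weights(rng, d=5, num_samples=1000)
print(w.shape)  # (1000, 5)
```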

fehiepsi commented 3 months ago

Closed because we now support MCMC methods like DiscreteHMCGibbs, and there are examples in the documentation of those methods.