pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

ZeroInflatedDistribution Errors with continuous/discrete distributions #1177

Closed EdwardRaff closed 2 years ago

EdwardRaff commented 2 years ago

If I build a model as

import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro import sample
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

def model(X, Y):
  print("Start:", X.shape, Y.shape)

  N = Y.shape[0]

  with numpyro.plate('N', N):
    # This errors out and changes the shape?
    v = sample('v', dist.ZeroInflatedDistribution(dist.BinomialProbs(0.5), gate=0.5))

  return

SEED = 1337
args = {'algo': 'NUTS', 'num_warmup': 300*1, 'num_samples': 450*1, 'num_chains': 1}
rng_key, rng_key_predict = random.split(random.PRNGKey(SEED))

kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
mcmc = MCMC(
    kernel,
    num_warmup=args['num_warmup'],
    num_samples=args['num_samples'],
    num_chains=args['num_chains'],
    progress_bar=True,
)
# X_big and Y_big are my data arrays (not shown here)
mcmc.run(rng_key, X_big, Y_big)

I get this confusing error message:

    441         assert (
    442             self._gibbs_sites
--> 443         ), "Cannot detect any discrete latent variables in the model."
    444         return super().init(rng_key, num_warmup, init_params, model_args, model_kwargs)
    445 

AssertionError: Cannot detect any discrete latent variables in the model.

If I then switch to kernel = NUTS(model), I get this:

RuntimeError: MCMC only supports continuous sites or discrete sites with enumerate support, but got ExpandedDistribution.

The errors get more confusing when I move toward the model I actually want, a zero-inflated Beta:

def model(X, Y):
  print("Start:", X.shape, Y.shape)

  N = Y.shape[0]

  with numpyro.plate('N', N):
    v = sample('v', dist.ZeroInflatedDistribution(dist.Beta(3,3), gate=0.5)) #This errors out and changes the shape? 

  return

According to this assert, ZeroInflatedDistribution apparently does not allow zero-inflating a continuous distribution?

    660         batch_shape = lax.broadcast_shapes(jnp.shape(gate), base_dist.batch_shape)
    661         (self.gate,) = promote_shapes(gate, shape=batch_shape)
--> 662         assert base_dist.is_discrete
    663         if base_dist.event_shape:
    664             raise ValueError(

AssertionError: 

My larger goal was to have a zero-one inflated Beta, but I would be happy to get just the zero-inflated version working.

EdwardRaff commented 2 years ago

@fehiepsi I appreciate the quick bug fix! Reading the code changes, is my understanding correct that this would still not allow ZeroInflatedDistribution to work with the Beta distribution? I'm trying to understand whether the restriction to discrete distributions is intended.

fehiepsi commented 2 years ago

I guess you can use the inflated beta as a likelihood. We don't have an inference algorithm that can deal with an inflated beta latent variable.

EdwardRaff commented 2 years ago

Does that mean dist.ZeroInflatedDistribution(dist.Beta(3,3), gate=0.5, obs=Y) would work fine, but dist.ZeroInflatedDistribution(dist.Beta(3,3), gate=0.5) would not?

I'm currently doing something like:

prob_zero = sample('Zero Inflation', dist.BernoulliProbs(0.5))
beta_aug = sample('Slab Response', dist.Beta(3, 3))  # Beta(3, 3) as in the model above; BetaProportion expects (mean, concentration) with mean in (0, 1)
v = numpyro.deterministic('Zero Inflated Response', (1-prob_zero)*beta_aug)

Do you think that is OK / have any other recommendations?

Either way, I want to make sure you know I appreciate NumPyro and your fast help! It has helped me get some cool stuff working with greater speed and ease than I ever could have managed without it.

fehiepsi commented 2 years ago

Does that mean dist.ZeroInflatedDistribution(dist.Beta(3,3), gate=0.5, obs=Y) would work fine, but dist.ZeroInflatedDistribution(dist.Beta(3,3), gate=0.5) would not?

Yes.

Do you think that is OK / have any other recommendations?

Yup, I think that is the right way to be able to perform inference for a latent zero-inflated beta variable.
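To make the generative meaning of this workaround concrete, here is a minimal plain-Python simulation of the same construction: a Bernoulli indicator multiplied into a Beta draw. It uses only the standard library rather than NumPyro, and the function name and the Beta(3, 3) / gate = 0.5 settings are illustrative assumptions taken from the snippets in this thread.

```python
import random


def sample_zero_inflated_beta(rng, gate=0.5, a=3.0, b=3.0):
    """Draw one value: exactly 0 with probability `gate`, else a Beta(a, b) draw."""
    is_zero = 1 if rng.random() < gate else 0   # Bernoulli(gate) indicator
    slab = rng.betavariate(a, b)                # continuous "slab" component
    return (1 - is_zero) * slab                 # same product as in the workaround


rng = random.Random(1337)
draws = [sample_zero_inflated_beta(rng) for _ in range(10_000)]

# Fraction of exact zeros should be close to the gate probability.
zero_fraction = sum(d == 0.0 for d in draws) / len(draws)

# The non-zero draws should look like plain Beta(3, 3) samples (mean 0.5).
nonzero = [d for d in draws if d != 0.0]
nonzero_mean = sum(nonzero) / len(nonzero)

print(f"zero fraction: {zero_fraction:.3f}, non-zero mean: {nonzero_mean:.3f}")
```

In the NumPyro version the indicator is a discrete latent site, which is presumably why a sampler that can handle discrete sites (such as the DiscreteHMCGibbs kernel used earlier in this thread) is needed alongside NUTS.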

ecotner commented 2 years ago

Hi, I am running into a similar error but have not been able to resolve it. I keep hitting the assertion assert base_dist.support.is_discrete, which makes it seem like ZeroInflatedDistribution does not support continuous-valued distributions like the Beta. My model is:

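import numpyro
from numpyro.distributions import Beta, TruncatedNormal, ZeroInflatedDistribution

def model(data):
    α = numpyro.sample("alpha", TruncatedNormal(0, 1, low=0))
    β = numpyro.sample("beta", TruncatedNormal(0, 1, low=0))
    p = numpyro.sample("p", Beta(1, 1))
    # Observed zero-inflated Beta site; this line triggers the is_discrete assertion
    R = ZeroInflatedDistribution(Beta(α, β), gate=p)
    return numpyro.sample("rate", R, obs=data)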

So the ZI site is not latent, which, from the discussion above, sounds like it should be OK? Or does the ZI distribution only work with discrete distributions? I am using NUTS to start but plan to switch to SVI once I work out the kinks, in case that is relevant.
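For reference, the log-likelihood that an observed zero-inflated Beta site would need to evaluate is a two-part mixture: log(gate) for an exact zero, and log(1 - gate) plus the Beta log-density otherwise. The sketch below writes that out in plain Python with the standard library only. The function names are hypothetical and this is not NumPyro's implementation; it only illustrates the quantity being discussed.

```python
import math


def beta_log_pdf(y, a, b):
    """Log-density of Beta(a, b) at y, for 0 < y < 1."""
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)  # log B(a, b)
    return (a - 1) * math.log(y) + (b - 1) * math.log1p(-y) - log_norm


def zero_inflated_beta_log_prob(y, gate, a, b):
    """Mixture of a point mass at 0 (weight `gate`) and a Beta(a, b) density."""
    if y == 0.0:
        return math.log(gate)                        # point-mass component
    return math.log1p(-gate) + beta_log_pdf(y, a, b)  # continuous component


# Example with the settings used earlier in the thread: Beta(3, 3), gate = 0.5.
print(zero_inflated_beta_log_prob(0.0, 0.5, 3.0, 3.0))
print(zero_inflated_beta_log_prob(0.5, 0.5, 3.0, 3.0))
```

One possible way to use a term like this in a NumPyro model, while the assertion is in place, would be to compute it with jax.numpy operations and add it through numpyro.factor; that is an assumption on my part, not something confirmed in this thread.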