pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Got runtime error when using HMC / MCMC together with sequential enumeration #3343

Closed: ljlin closed this issue 1 month ago

ljlin commented 7 months ago

I am trying to implement a CRBD (constant-rate birth-death) model, which contains both continuous and discrete random variables, as described in this paper (Algorithm 2, "Basic birth-death"), and apply HMC to it.

But I got a runtime error that says:

ValueError: Continuous inference cannot handle discrete sample site './isSpeciation'. Consider enumerating that variable as documented in https://pyro.ai/examples/enumeration.html . If you are already enumerating, take care to hide this site when constructing an autoguide, e.g. guide = AutoNormal(poutine.block(model, hide=['./isSpeciation'])).

Am I using Pyro's HMC the wrong way, or is Pyro's HMC not compatible with {'enumerate': 'sequential'}?

What should I do to apply HMC together with {'enumerate': 'sequential'} to this CRBD model?

Thanks for helping.

Code:

```python
import argparse
import sys

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import MCMC, NUTS

# goesExtinct recurses once per speciation event, so allow deep recursion.
sys.setrecursionlimit(32767)


def goesExtinct(prefix, time, la, mu):
    """Simulate one lineage for `time` time units; return True if it goes extinct."""
    # Waiting time to the next event on this lineage. (With speciation rate la
    # and extinction rate mu, the standard CRBD rate here would be la + mu.)
    waitingTime = pyro.sample(f"{prefix}/waitingTime", dist.Exponential(la))
    if waitingTime > time:
        # No event before the present: the lineage survives.
        b_waitingTime = False
    else:
        # Discrete site: speciation (1) or extinction (0).
        # NUTS raises the ValueError quoted above when it reaches this site.
        isSpeciation = pyro.sample(
            f"{prefix}/isSpeciation",
            dist.Bernoulli(la / (la + mu)),
            infer={'enumerate': 'sequential'},
        )
        if isSpeciation:  # https://pyro.ai/examples/enumeration.html
            # Speciation: the lineage goes extinct only if both daughters do.
            x = goesExtinct(f"{prefix}/x", time - waitingTime, la, mu)
            y = goesExtinct(f"{prefix}/y", time - waitingTime, la, mu)
            b_isSpeciation = x and y
        else:
            # Extinction event.
            b_isSpeciation = True
        b_waitingTime = b_isSpeciation
    return b_waitingTime


def model(time):
    la = pyro.sample("lamda", dist.Gamma(1, 1))
    mu = pyro.sample("mu", dist.Gamma(1, 1))
    obs = goesExtinct(".", time, la, mu)
    # Condition on extinction: log-weight 0 if the lineage goes extinct, -inf otherwise.
    pyro.factor("obs", torch.tensor(0.0) if obs else torch.tensor(float("-inf")))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', choices=["HMC", "MAPPL"], required=True)
    parser.add_argument('--time', type=float, required=True)
    parser.add_argument('--progress_bar', action='store_true')
    parser.add_argument('--num_chains', type=int, required=True)
    parser.add_argument('--warmup_steps', type=int, required=True)
    parser.add_argument('--num_samples', type=int, required=True)
    args = parser.parse_args()
    print(args)

    if args.config == "HMC":
        nuts_kernel = NUTS(model)
    else:
        # Only the HMC configuration is set up in this script.
        raise NotImplementedError(f"--config {args.config} is not handled here")

    mcmc = MCMC(
        nuts_kernel,
        warmup_steps=args.warmup_steps,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        disable_progbar=not args.progress_bar,
    )
    mcmc.run(args.time)
    mcmc.print_summary()


if __name__ == '__main__':
    main()
```
gizemcaylak commented 2 months ago

Hi, were you able to find a solution to this issue? I am having similar issues with a phylogenetics model.

fehiepsi commented 1 month ago

HMC does not work with discrete latent variables. (With enumeration, discrete latent variables are marginalized out, so HMC can still work.)

gizemcaylak commented 1 month ago


Thanks. Do you have an idea which inference algorithm in NumPyro would work if I want to sample discrete variables? As far as I have tried, SVI doesn't work.

fehiepsi commented 1 month ago

I guess you could try DiscreteHMCGibbs.

fehiepsi commented 1 month ago

Please use our forum https://forum.pyro.ai/ for questions. We mainly use GitHub for tracking bugs and feature requests.