pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Large potential energy while using `HMCGibbs` at the initial stage #1813

Closed · disadone closed this issue 3 months ago

disadone commented 3 months ago

The Gibbs part is basically stripped from the following code because it is not related to the problem. I just don't understand why `vv_state.z_grad` in `hmc.py` is extremely large.

https://github.com/pyro-ppl/numpyro/blob/0ba1306ea41c9865169360cb17c62efe7fc2bc94/numpyro/infer/hmc.py#L350
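
For reference, the state that HMC starts from can be inspected outside of MCMC. Below is a minimal sketch using a toy model (not the full ramping model) with `numpyro.infer.util.initialize_model`, which returns the unconstrained initial parameters together with their potential energy and gradient, i.e. the same `z_grad`:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model

def toy_model(y):
    # a positive rate with a Gamma prior, loosely mirroring `gamma` in the model below
    rate = numpyro.sample("gamma", dist.Gamma(2.0, 0.05))
    numpyro.sample("obs", dist.Poisson(rate), obs=y)

model_info = initialize_model(
    jax.random.PRNGKey(0), toy_model, model_args=(jnp.array([1.0, 2.0]),)
)
print(model_info.param_info.z)                 # unconstrained initial values
print(model_info.param_info.potential_energy)  # should be finite
print(model_info.param_info.z_grad)            # should not be extremely large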

Here is a simplified version of the code. The acceptance rate stays at 0. Is there any problem?

import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.nn as nn
import numpyro
import numpyro.distributions as dist

key = jnr.PRNGKey(0)
ramp_ev=jnr.uniform(key, (800,100))    # synthetic observations: 800 trials x 100 time bins
ramp_ed=jnr.randint(key,(800,),50,100) # per-trial end indices in [50, 100)

# from latimerOrigin import Latimer  # local module, unused in this simplified repro
import numpyro.handlers as handlers

class RampingModel:
    def __init__(self,
                 ev_batch,ed_batch,
                 dt=10e-3
                 ):
        self.dt=dt
        self.ev_batch,self.ed_batch=ev_batch,ed_batch
        self.constants()

    def constants(self):

        self.x0_mu=0.0
        self.x0_sigma=1.0

        self.beta_mu=0.0
        self.beta_sigma=0.1

        # self.omega_alpha_=1.02
        self.omega_alpha=0.02
        self.omega_beta=0.02

        self.gamma_alpha=2.0
        self.gamma_beta=0.05

    def model(self,ev_batch,ed_batch): # single trial spike train series, coherence

        T,N=ev_batch.shape

        x0=numpyro.sample("x_0",dist.Normal(self.x0_mu,self.x0_sigma))
        beta=numpyro.sample("beta",dist.Normal(
            self.beta_mu*jnp.zeros(5),
            self.beta_sigma*jnp.ones(5)))
        omega2=numpyro.sample("omega2",dist.InverseGamma(
            self.omega_alpha,self.omega_beta))

        x=numpyro.deterministic('x',jnp.ones([T,N]))
        tau=numpyro.deterministic('tau',jnp.ones(N)*T)
        gamma=numpyro.sample("gamma",dist.Gamma(
            self.gamma_alpha,self.gamma_beta))
        # FIXME: gamma becomes NaN or very large;
        # not related to the concrete distribution form, not related to the gamma value,
        # not related to the Gibbs part

        mint=jnp.minimum(tau+1,ed_batch)

        with handlers.mask(mask=(jnp.arange(T)[...,None]<mint)):
            # y_mean=jnp.log(1.0+jnp.exp(gamma*jnp.where((x>=1.0),1.0,x)))*self.dt
            y_mean=jnp.log(1.0+jnp.exp(gamma*x))*self.dt

            numpyro.sample("obs",
                           dist.Poisson(jnp.maximum(y_mean,1e-16)),
                           obs=ev_batch)
            jax.debug.print("gamma={gamma}",gamma=gamma)

    def gibbs_fn(self,rng_key,gibbs_sites,hmc_sites):
        print("---- gibbs start ----")
        jax.debug.print("jax: ---- gibbs start ----")

        T,N=self.ev_batch.shape
        gamma=hmc_sites['gamma']
        omega2=gibbs_sites['omega2']
        beta=gibbs_sites['beta']
        x0=gibbs_sites['x_0']

        x=dist.Normal(jnp.zeros([T,N]),jnp.ones([T,N])).sample(rng_key)
        tau=jnp.ones(N)*T

        # something updated here but not related to large `gamma`

        jax.debug.print("jax: beta={beta} gamma={gamma} omega2={omega2} x0={x0}",
                        beta=beta,gamma=gamma,omega2=omega2,x0=x0)
        jax.debug.print("jax: xmax={xmax}, xmin={xmin}",xmax=x.max(),xmin=x.min())

        print("---- gibbs end ----")
        jax.debug.print("jax: ---- gibbs end ----")
        return {'beta':beta,'x_0':x0,'omega2':omega2,'x':x,'tau':tau}

from numpyro.infer import MCMC, HMCGibbs, HMC
la=RampingModel(ramp_ev.T,ramp_ed)
print(ramp_ed[:10],ramp_ed[-10:])

def MCMCAll(la,ramp_ev,ramp_ed):

    gibbs_fn=la.gibbs_fn
    hmc=HMC(la.model,
            target_accept_prob=0.8,step_size=1.0,
            )

    kernel=HMCGibbs(hmc,gibbs_fn=gibbs_fn,gibbs_sites=['beta','x_0','omega2','x','tau'])

    mcmc = MCMC(kernel, num_warmup=10, num_samples=10, progress_bar=True,num_chains=1)
    mcmc.run(jnr.PRNGKey(0),jnp.nan_to_num(ramp_ev).T,ramp_ed,init_params={
        'x_0':0.5,'beta':jnp.array([0.0,0.0,0.0,0.0,0.0]),
        'x':0.0*jnp.zeros_like(ramp_ev.T),'gamma':50.0,'omega2':0.005})

    mcmc.print_summary()
    return mcmc

mcmc=MCMCAll(la,ramp_ev,ramp_ed)

The outputs are:

[53 57 93 59 68 89 85 55 89 52] [79 85 68 68 80 92 52 94 53 97]
gamma=60.00014114379883
gamma=5.300010681152344
gamma=5.300010681152344
gamma=5.184705457665547e+21
  0%|          | 0/20 [00:00<?, ?it/s]
---- gibbs start ----
---- gibbs end ----
warmup:  20%|██        | 4/20 [00:01<00:03,  4.41it/s, 377 steps of size 1.07e-03. acc. prob=0.00]
gamma=5.184705457665547e+21
jax: ---- gibbs start ----
jax: beta=[0. 0. 0. 0. 0.] gamma=5.184705457665547e+21 omega2=0.004999999888241291 x0=0.5
jax: xmax=4.439892768859863, xmin=-4.001267433166504
jax: ---- gibbs end ----
gamma=5.184705457665547e+21
gamma=nan
gamma=nan
gamma=nan
gamma=nan
...
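
A likely source of the non-finite values above: `jnp.log(1.0+jnp.exp(gamma*x))` overflows to `inf` once `gamma*x` exceeds roughly 88 in float32, after which the gradient turns into `nan`. `jax.nn.softplus` evaluates the same quantity stably; a minimal sketch (the replacement line is a suggestion, not part of the original code):

import jax.numpy as jnp
import jax.nn as nn

z = jnp.array(100.0)
print(jnp.log(1.0 + jnp.exp(z)))  # inf: exp(100.0) overflows in float32
print(nn.softplus(z))             # 100.0: stable log(1 + exp(z))

# hypothetical replacement inside the model:
# y_mean = nn.softplus(gamma * x) * self.dt
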
disadone commented 3 months ago

It seems that the `jax.value_and_grad` call in `hmc_util.py` is the key contributor to this, though I still don't understand why.

https://github.com/pyro-ppl/numpyro/blob/40565d0e80aa9b342b8a58a593b1acf449ed6e01/numpyro/infer/hmc_util.py#L252
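
That line evaluates the potential energy and its gradient at the current unconstrained state. The same computation can be reproduced by hand on the `potential_fn` returned by `initialize_model`; a minimal sketch, reusing `model_info` from the toy sketch above:

# pe_fn maps unconstrained parameters to the potential energy (negative log joint)
pe_fn = model_info.potential_fn
pe, z_grad = jax.value_and_grad(pe_fn)(model_info.param_info.z)
print(pe, z_grad)  # an inf or huge value here reproduces the reported behavior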

fehiepsi commented 3 months ago

Hi @disadone, could you post the question on the forum https://forum.pyro.ai/? We have a couple of inference utilities that you might want to use for diagnosing issues with your model/data.
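
For context, one such diagnostic pattern is to run the model once under `handlers.seed` and inspect every sample site with `handlers.trace`; a minimal sketch, reusing the toy model above:

import numpyro.handlers as handlers

with handlers.seed(rng_seed=0):
    tr = handlers.trace(toy_model).get_trace(jnp.array([1.0, 2.0]))
for name, site in tr.items():
    if site["type"] == "sample":
        print(name, site["value"])  # look for nan/inf or unexpected scales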

disadone commented 3 months ago

> Hi @disadone, could you post the question on the forum forum.pyro.ai? We have a couple of inference utilities that you might want to use for diagnosing issues with your model/data.

Thank you, I have moved my post to the forum. It would be appreciated if there were a `potential_energy` example.
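
For reference, a minimal `potential_energy` sketch, again assuming the toy model above; note that `numpyro.infer.util.potential_energy` expects parameters in unconstrained space, so a positive `gamma` is passed in log space:

from numpyro.infer.util import potential_energy

params = {"gamma": jnp.log(jnp.array(50.0))}  # unconstrained value for gamma=50
pe = potential_energy(toy_model, (jnp.array([1.0, 2.0]),), {}, params)
print(pe)  # scalar potential energy (negative log joint, with Jacobian term)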