pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.18k stars 238 forks source link

Inconsistent MCMC Results Based On * Operator #1488

Closed mlpotter closed 2 years ago

mlpotter commented 2 years ago

Issue

I receive inconsistent MCMC results based on simply using the operator or not. Model1 (correct result) and Model2 (incorrect result) gives different answers despite all the mathematical description being equivalent. The only difference between model2 versus model1 is the operator.

(jnp.exp(-(jnp.power(t2*lam,  k) - jnp.power(t1*lam,  k)))) # gives the wrong answer
(jnp.exp(-(jnp.power(t2,  k) - jnp.power(t1,  k))*jnp.power(lam,k) ))) # correct answer

Package Versions

OS

Edition Windows 10 Home Version 21H2 Installed on ‎12/‎8/‎2021 OS build 19044.2006 Experience Windows Feature Experience Pack 120.2212.4180.0

Code

# -*- coding: utf-8 -*-
import numpy as np

import torch
from torch.distributions import Weibull

from jax import random
import jax.numpy as jnp
import jax

from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.handlers import trace
import numpyro.distributions as dist
import numpyro

from time import time
import random as rnd

numpyro.enable_x64()
cpu_cores = 2
numpyro.set_host_device_count(cpu_cores)

torch.manual_seed(0)
rnd.seed(0)
np.random.seed(0)
print(np.__version__)
print(numpyro.__version__)
print(jax.__version__)
print(torch.__version__)

def generate_data_no_features(N,k,lam,intervals=6,intervals_length=4):
    print("Shape Parameter",k)
    print("Rate Parameter",lam)

    weibull = Weibull(torch.ones(N,1)*1/lam,torch.ones(N,1)*k)

    t_fail = weibull.sample()

    t1 = []; t2 = []; y = [];
    print(f"Number of intervals: {intervals}")
    print(f"Length of intervals: {intervals_length}")
    for i in torch.arange(intervals):
        t_start = intervals_length*i
        t_end = intervals_length*(i+1)
        failed = torch.logical_and(t_fail <= t_end,t_fail > t_start).type(torch.float)

        y.append(failed.ravel())
        t1.append(t_start*torch.ones_like(failed).ravel())
        t2.append(t_end*torch.ones_like(failed).ravel())

        if len(failed) == 0:
            break

        t_fail = t_fail[np.logical_not(failed).type(torch.bool)]

    t1 = torch.hstack(t1).unsqueeze(1)
    t2 = torch.hstack(t2).unsqueeze(1)
    y = torch.hstack(y).squeeze()
    print(f"The timeline ends at {t_end}")
    return t1.type(torch.float).squeeze().numpy(),t2.type(torch.float).squeeze().numpy(),y.type(torch.float).squeeze().numpy()

tfix = 5

#(jnp.exp(-(jnp.power(t2*lam,  k) - jnp.power(t1*lam,  k)))) gives the wrong answer
#(jnp.exp(-(jnp.power(t2,  k) - jnp.power(t1,  k))*jnp.power(lam,k) ))) ## correct answer

def model1(t1, t2, y=None):
    k = numpyro.sample("k", dist.LogNormal(0,1))
    r = numpyro.sample("r", dist.Uniform(0,1))

    lam = numpyro.deterministic("lam", jnp.power(-jnp.log(r), 1 / (k)) / tfix)

    p = 1.0 - jnp.exp(-(jnp.power(t2,  k) - jnp.power(t1,  k))*jnp.power(lam,k) )

    numpyro.sample("likelihood", dist.Bernoulli(probs=p), obs=y)

def model2(t1, t2, y=None):
    k = numpyro.sample("k", dist.LogNormal(0,1)) 
    r = numpyro.sample("r", dist.Uniform(0,1))

    lam = numpyro.deterministic("lam", jnp.power(-jnp.log(r), 1 / (k)) / tfix)

    p = 1.0 - jnp.exp(-(jnp.power(t2*lam,  k) - jnp.power(t1*lam,  k)) )

    numpyro.sample("likelihood", dist.Bernoulli(probs=p), obs=y)

if __name__ == "__main__":
    t1,t2,y = generate_data_no_features(100,1.1,.01,30,4)

    rng_keys = jax.random.split(random.PRNGKey(123), cpu_cores)
    print("Model 1")
    nuts_kernel = NUTS(model1)
    mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=cpu_cores, chain_method='parallel')

    with numpyro.handlers.seed(rng_seed=1):
        exec_trace = trace(model1).get_trace(t1, t2, y)
        print(numpyro.util.format_shapes(exec_trace, compute_log_prob=True))

    start = time()
    mcmc.run(rng_keys, t1=t1, t2=t2, y=y)
    end = time()

    print(mcmc.print_summary(exclude_deterministic=False))
    print("Compile time + Sampling Time {:.4f}".format(end - start))

    print("Model 2")
    nuts_kernel = NUTS(model2)
    mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=cpu_cores, chain_method='parallel')

    with numpyro.handlers.seed(rng_seed=1):
        exec_trace = trace(model2).get_trace(t1,t2,y)
        print(numpyro.util.format_shapes(exec_trace,compute_log_prob=True))

    start = time()
    mcmc.run(rng_keys, t1=t1, t2=t2, y=y)
    end = time()
    print("Compile time + Sampling Time {:.4f}".format(end - start))

    print(mcmc.print_summary(exclude_deterministic=False))
martinjankowiak commented 2 years ago

@mlpotter i don't think this is surprising. libraries like jax and pytorch do computations using floating point operations not exact arithmetic. so computations that are equivalent mathematically won't be equivalent in practice due to the accumulation of floating point errors, numerical under/overflow, etc.

factoring like in "## correct answer" is just good practice, which explains why you get a reasonable answer in that case. depending on the nature of the mathematical operations involved, doing this sort of thing can be essential for good results

feel free to re-open this issue if you think this is an actual bug (and not just a pitfall of doing floating point arithmetic, which is basically common to all performant numerical systems)

mlpotter commented 2 years ago

@martinjankowiak I tried this unfactorized versus factorized version in rstan. Rstan does not seem to have the same issue as numpyro and returns the correct parameters for both methods.

fehiepsi commented 2 years ago

Quoted from @mlpotter in #1493

Reopening https://github.com/pyro-ppl/numpyro/issues/1488 because of comparison with rstan in R. @martinjankowiak I tried this unfactorized versus factorized version in rstan. Rstan does not seem to have the same issue as numpyro and returns the correct parameters for both methods. Not sure if rstan does optimization under the hood that numpyro does not, but now I am suspecting possible bug.

fehiepsi commented 2 years ago

@mlpotter It is better to raise those numerical issues to jax folks. You can just create a function with arguments are latent variables and try to see if there are differences in the evaluation. You use the following code to make an issue in jax. I think wrong grad computation happens when t1 takes value 0 in jnp.power(t1 * lam, k). In those cases, it is better to factor as @martinjankowiak suggested above.

import jax.numpy as jnp

t1 = jnp.array([0., 1.])
t2 = jnp.array([1., 2.])

def f1(k, r):
    lam = jnp.power(-jnp.log(r), 1 / k)
    p = jnp.exp(-(jnp.power(t2*lam, k) - jnp.power(t1*lam, k)))
    return p.sum()

def f2(k, r):
    lam = jnp.power(-jnp.log(r), 1 / k)
    p = jnp.exp(-(jnp.power(t2, k) - jnp.power(t1, k)) * jnp.power(lam, k))
    return p.sum()

print(jax.device_get(jax.value_and_grad(f1, (0, 1))(0.1, 0.9)))
print(jax.device_get(jax.value_and_grad(f2, (0, 1))(0.1, 0.9)))

which returns

(array(1.89246643), (array(nan), array(nan)))
(array(1.89246643), (array(-0.07768232), array(1.0791475)))