CUQI-DTU / CUQIpy

https://cuqi-dtu.github.io/CUQIpy/
Apache License 2.0

NUTS got stuck when sampling a standard normal distribution #538

Open chaozg opened 2 months ago


Description

The NUTS sampler (`cuqi.experimental.mcmc.NUTS`) gets stuck when sampling a one-dimensional standard normal distribution.

"this is very good diagnosis. I am not sure what is the issue, it seems stuck where we have high PDF (around zero)" -- from @amal-ghamdi

Example to reproduce

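```python
import cuqi
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)

# 1D standard normal target: mean 0, variance 1
gaussian = cuqi.distribution.Gaussian(np.array([0.0]), np.array([1.0]))

# NUTS sampler from the experimental MCMC module, started at x = 0.1
sampler = cuqi.experimental.mcmc.NUTS(gaussian, initial_point=np.array([0.1]))
sampler.warmup(10000)  # 10000 warmup (tuning) steps
sampler.sample(10000)  # 10000 sampling steps
samples = sampler.get_samples()

# The trace should look like white noise around 0; a flat segment means the chain is stuck
samples.plot_trace()
plt.show()
```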

(Image: trace plot produced by `samples.plot_trace()`.)
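To put a number on "stuck", one can summarize the chain instead of only eyeballing the trace. The helper below is a hypothetical sketch (`chain_diagnostics` is not part of CUQIpy). It reports the fraction of consecutive steps where the state did not move by more than `tol`, plus the sample mean and standard deviation, which should be near 0 and 1 for a standard normal target. Applying it to the sampler output assumes the CUQIpy `Samples` object exposes the raw array as `samples.samples` with shape `(dim, Ns)`, i.e. `chain_diagnostics(samples.samples[0, :])`. The usage below runs on synthetic chains, not on real NUTS output.

```python
import numpy as np

def chain_diagnostics(chain, tol=0.0):
    """Summarize a 1-D MCMC chain to check whether it is stuck (hypothetical helper).

    stuck_fraction: share of consecutive steps with |x[i+1] - x[i]| <= tol.
    A small positive tol also catches chains that creep in tiny steps.
    mean/std: should be close to 0 and 1 for a standard normal target.
    """
    x = np.asarray(chain, dtype=float).ravel()
    if x.size < 2:
        raise ValueError("need at least two samples")
    unchanged = np.abs(np.diff(x)) <= tol
    return {
        "stuck_fraction": float(unchanged.mean()),
        "mean": float(x.mean()),
        "std": float(x.std(ddof=1)),
    }

# Synthetic chains for illustration only (not sampler output):
# a healthy chain of independent N(0, 1) draws, and a chain frozen at 0.1 after 100 steps.
rng = np.random.default_rng(0)
healthy = rng.standard_normal(10_000)
frozen = np.concatenate([rng.standard_normal(100), np.full(9_900, 0.1)])

print(chain_diagnostics(healthy))
print(chain_diagnostics(frozen))
```

A healthy chain gives a stuck fraction of 0 with mean and standard deviation close to 0 and 1, while the frozen chain gives a stuck fraction of about 0.99. Depending on whether the warmup draws are included in the returned samples, they may need to be discarded before computing these numbers.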