CUQI-DTU / CUQIpy

https://cuqi-dtu.github.io/CUQIpy/
Apache License 2.0

Add (iid) truncated normal and also add grad to normal #542

Closed: chaozg closed this pull request 3 weeks ago.

chaozg commented 1 month ago

- Fixes #321: add TruncatedNormal
- Fixes #548: add gradient to Normal
- Fixes #546: add a demo on TruncatedNormal
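For reference, here is the standard one-dimensional truncated normal density with mean $\mu$, standard deviation $\sigma$ and bounds $a < b$. This is the textbook form, not taken from this PR. The iid case applies it componentwise and sums the log densities. Here $\phi$ and $\Phi$ are the standard normal pdf and cdf, $\alpha = (a-\mu)/\sigma$ and $\beta = (b-\mu)/\sigma$:

```math
p(x \mid \mu, \sigma, a, b) =
\begin{cases}
\dfrac{\phi\!\left(\frac{x-\mu}{\sigma}\right)}{\sigma\,\left[\Phi(\beta) - \Phi(\alpha)\right]}, & a \le x \le b,\\[1ex]
0, & \text{otherwise},
\end{cases}
\qquad
\frac{\partial}{\partial x}\log p(x \mid \mu, \sigma, a, b) = -\frac{x-\mu}{\sigma^{2}}, \quad a < x < b.
```

Inside the support, the normalizing constant does not depend on $x$. The gradient with respect to $x$ is therefore the same as for the untruncated Normal.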

Description

Test with Gibbs

```python
import cuqi
import numpy as np

np.random.seed(0)

# Prior on the mean parameter d
d = cuqi.distribution.Uniform(np.array([0]), np.array([1]))

# Identity forward model mapping d to the mean of x
A = cuqi.model.LinearModel(lambda x: x, adjoint=lambda x: x, range_geometry=1, domain_geometry=1)

# x | d: Normal with mean A(d) and std 1, truncated to [-1, inf)
# (np.inf instead of np.Inf, which was removed in NumPy 2.0)
x = cuqi.distribution.TruncatedNormal(A(d), std=np.array([1]), low=np.array([-1]), high=np.array([np.inf]))

# Gradient-based MALA samplers for both variables inside Gibbs
sampling_strategy = {
    "d": cuqi.experimental.mcmc.MALA(initial_point=np.array([0.5]), scale=0.1),
    "x": cuqi.experimental.mcmc.MALA(initial_point=np.array([1.0]), scale=0.3),
}
joint = cuqi.distribution.JointDistribution(d, x)
sampler = cuqi.experimental.mcmc.HybridGibbs(joint, sampling_strategy=sampling_strategy)
sampler.sample(10000)
samples = sampler.get_samples()
samples["x"].plot_trace()
samples["d"].plot_trace()
```

(Note: here x is truncated below at -1, as expected.)

[Figures: trace plots of x and d]
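The following sketch shows why no sample of x falls below -1. It is a minimal standalone MALA in plain NumPy, not CUQIpy's sampler. It targets a truncated normal with a hypothetical fixed mean `mu = 0.5`. The helper names and the step-size convention (`scale` as the proposal variance) are assumptions. CUQIpy's MALA may parameterize `scale` differently. Proposals outside the support get log density `-inf`, so they are always rejected and the chain never leaves `[low, high]`.

```python
import numpy as np

# Hypothetical standalone sketch, not CUQIpy's implementation:
# MALA targeting N(mu, sigma^2) truncated to [low, high].
mu, sigma, low, high = 0.5, 1.0, -1.0, np.inf

def logpdf(x):
    # Unnormalized log density: -inf outside the support
    if x < low or x > high:
        return -np.inf
    return -0.5 * ((x - mu) / sigma) ** 2

def grad(x):
    # Gradient of the log density inside the support
    return -(x - mu) / sigma ** 2

def mala(x0, scale, n, rng):
    x = x0
    samples = np.empty(n)
    for i in range(n):
        # Langevin proposal: drift along the gradient plus Gaussian noise
        mean_fwd = x + 0.5 * scale * grad(x)
        y = mean_fwd + np.sqrt(scale) * rng.standard_normal()
        lp_y = logpdf(y)
        if np.isfinite(lp_y):
            mean_bwd = y + 0.5 * scale * grad(y)
            # log q(x | y) - log q(y | x) for the asymmetric proposal
            log_q_ratio = (-(x - mean_bwd) ** 2 + (y - mean_fwd) ** 2) / (2 * scale)
            log_alpha = lp_y - logpdf(x) + log_q_ratio
            if np.log(rng.uniform()) < log_alpha:
                x = y
        # Proposals outside [low, high] have logpdf = -inf and are rejected
        samples[i] = x
    return samples

rng = np.random.default_rng(0)
samples = mala(x0=1.0, scale=0.3, n=10000, rng=rng)
print(samples.min() >= low)  # True: no sample falls below the lower bound
```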

chaozg commented 1 month ago

Thanks @amal-ghamdi and @nabriis for your review. I think I have addressed your comments, so I'm requesting another review.

Note that the existing Normal lacked grad() (#548), so I also added it in this PR.
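The next sketch shows what a Normal gradient computes. It uses plain NumPy and is independent of CUQIpy. The function names are hypothetical and are not CUQIpy's API. For an iid Normal, the gradient of the log density with respect to x is `-(x - mean) / std**2`. A central finite-difference check confirms the analytic form.

```python
import numpy as np

def normal_logpdf(x, mean, std):
    # Log density of an iid Normal, summed over components
    return np.sum(-0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi))

def normal_grad(x, mean, std):
    # Gradient of the log density with respect to x
    return -(x - mean) / std ** 2

x = np.array([0.3, -1.2, 2.0])
mean = np.array([0.0, 0.5, 1.0])
std = np.array([1.0, 2.0, 0.5])

# Central finite-difference check of the analytic gradient
eps = 1e-6
fd = np.array([
    (normal_logpdf(x + eps * e, mean, std) - normal_logpdf(x - eps * e, mean, std)) / (2 * eps)
    for e in np.eye(len(x))
])
print(np.allclose(fd, normal_grad(x, mean, std), atol=1e-5))  # True
```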

chaozg commented 1 month ago

With @amal-ghamdi's help today, the gradient of TruncatedNormal finally works as a likelihood 😄. Now I'm requesting further review from @nabriis.
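Using TruncatedNormal as a likelihood requires the gradient with respect to its mean, not with respect to x. The sketch below computes that gradient for the exact (normalized) 1D density, using `scipy.stats.norm`. The helper names and parameter values are hypothetical, and the sketch does not claim to mirror CUQIpy's implementation. Because the normalizer $Z = \Phi(\beta) - \Phi(\alpha)$ depends on the mean, it contributes a second term, $(\phi(\beta) - \phi(\alpha))/(\sigma Z)$, alongside the usual $(x - \mu)/\sigma^2$.

```python
import numpy as np
from scipy.stats import norm

# Hypothetical helpers: 1D truncated normal likelihood p(x | mu)
# with fixed std sigma and bounds [low, high].
def loglike(mu, x, sigma, low, high):
    alpha, beta = (low - mu) / sigma, (high - mu) / sigma
    # Normalized log density, including the -log Z term that depends on mu
    return (norm.logpdf((x - mu) / sigma) - np.log(sigma)
            - np.log(norm.cdf(beta) - norm.cdf(alpha)))

def grad_mu(mu, x, sigma, low, high):
    alpha, beta = (low - mu) / sigma, (high - mu) / sigma
    Z = norm.cdf(beta) - norm.cdf(alpha)
    # Data term plus the derivative of -log Z with respect to mu
    return (x - mu) / sigma**2 + (norm.pdf(beta) - norm.pdf(alpha)) / (sigma * Z)

mu, x, sigma, low, high = 0.2, 1.0, 1.0, -1.0, np.inf
eps = 1e-6
fd = (loglike(mu + eps, x, sigma, low, high) - loglike(mu - eps, x, sigma, low, high)) / (2 * eps)
print(np.isclose(fd, grad_mu(mu, x, sigma, low, high)))  # True
```

If the density is left unnormalized, the second term disappears. That is harmless when x is sampled with the mean held fixed. It does change the gradient when the mean itself is being inferred.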

chaozg commented 1 month ago

Hi @amal-ghamdi, since I just added demo38 on using TruncatedNormal, I'm requesting your review again. Plots from demo38 can be seen in #549.

chaozg commented 3 weeks ago

Thanks, @amal-ghamdi and @nabriis, for your review! I'm merging into main now.