Julia-Tempering / Pigeons.jl

Sampling from intractable distributions, with support for distributed and parallel methods
https://pigeons.run/dev/
GNU Affero General Public License v3.0

Running my own log-likelihood function #98

Open urirolls5987 opened 1 year ago

urirolls5987 commented 1 year ago

Hi! Thank you for an incredible package, this is really quite amazing. I ran into an issue when trying to run my own log-likelihood function and sample from it, similar to how I would in Dynesty.jl:

    loglikelihood(sample_params) = somefunction(sample_params)
    smplr = NestedSampler(ndim, nlive=500)
    res = dysample(loglikelihood, identity, smplr; dlogz=0.5)

In Pigeons there is an example of how to do something like this (general_target.jl), but it's unclear how to define the prior.

For example, how would I specify a uniform prior distribution (essentially the identity, as in the Dynesty.jl call above)?

Thanks!

miguelbiron commented 1 year ago

Hi, sorry for the delay. I've been working on a PR that exposes a shortcut interface to Pigeons for the case where your log-likelihood is given by a black-box function and your reference is determined by a Distributions.jl object.

In the meantime, if you want to check it out, you can install it via

    ]add git@github.com:Julia-Tempering/Pigeons.jl.git#bread-crumbs-api

Note: this interface is still experimental and subject to change.

To replicate the black-box version of the unidentifiable model example in the docs using this simplified interface, run:

    using Pigeons, Distributions, MCMCChains

    # define the target loglikelihood 
    function unid_log_potential(x; n_trials=100, n_successes=50) 
        p = prod(x)
        return n_successes*log(p) + (n_trials-n_successes)*log1p(-p)
    end
    ref_dist = product_distribution(Uniform(), Uniform()) # define the reference distribution 
    pt = pigeons(
        BreadCrumbs(unid_log_potential, ref_dist),
        n_rounds = 12,
        record = [traces]
    )

    # collect the statistics and convert to MCMCChains' Chains
    samples = Chains(sample_array(pt), variable_names(pt))
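
For reference, the function `unid_log_potential` above can be read as a binomial log-likelihood whose success probability is the product of the two parameters, with the constant binomial coefficient dropped. The symbols $n$ and $s$ are shorthand introduced here for `n_trials` and `n_successes`; they do not appear in the code:

```latex
\ell(x_1, x_2) = s \,\log(x_1 x_2) + (n - s)\,\log(1 - x_1 x_2),
\qquad n = 100,\quad s = 50 .
```

Since $\ell$ depends on $(x_1, x_2)$ only through the product $x_1 x_2$, every pair with the same product gets the same value, which is what makes the model unidentifiable. The reference `product_distribution(Uniform(), Uniform())` puts independent uniform densities on $[0, 1]$ for each coordinate, so its log-density is constant (zero) on the unit square.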