Julia-Tempering / Pigeons.jl

Sampling from intractable distributions, with support for distributed and parallel methods
https://pigeons.run/dev/
GNU Affero General Public License v3.0

Advice for correctly sampling from the reference function (Stan model as input) #244

Open MarcoGallegos34 opened 4 months ago

MarcoGallegos34 commented 4 months ago

Hello,

I'm trying to use Pigeons with a Stan model as input. However, while following the steps in the "Stan model as input to pigeons" section of the Pigeons.jl documentation, I encountered the following issue:

[Image: screenshot, 2024-06-29 at 12:42:46 p.m.]

My guess is that the sampling procedure may not be appropriate given the parameter constraints of my model. I'm not sure about this and don't know where to look for different sampling functions. I'm new to Julia, so I don't know many of the packages. Is there any package or documentation you could recommend for me to look at?

Data is generated in the following way using R:

nobs = 50
original_mu = c(-15,15)
original_weights = c(1/2,1/2)
original_sd = rep(.2,2)

set.seed(9304)
x = c()
for(i in 1:nobs){

      index = sample(1:length(original_mu),1,prob=original_weights)
      x[i] = rnorm(1,original_mu[index],original_sd[index])
}

Stan code (gaussian mixture model):

data {
  int<lower=0> N;
  int<lower=0> nobs;
  vector[nobs] y;
  real<lower=0> sd_prior;
  int<lower=0,upper=1> include_prior;
  int<lower=0,upper=1> isRef;
}

parameters {
  vector[N] mu;
  vector<lower=0>[N] sd_cluster;
  simplex[N] omega;
}

model {
  // prior
  if(include_prior){
    for(i in 1:N){
      target += normal_lpdf(mu[i] | 0, sd_prior);
    }
  }

  for(i in 1:N){
    target += gamma_lpdf(sd_cluster[i] | .1, .1);
  }

  target += dirichlet_lpdf(omega | rep_vector(1.0,N));

  // likelihood
  if(!isRef){
    for(i in 1:nobs){
      vector[N] aux_log_lik_y_i;
      for(j in 1:N){
        aux_log_lik_y_i[j] = log(omega[j]) + normal_lpdf(y[i] | mu[j], sd_cluster[j]);
      }
      target += log_sum_exp(aux_log_lik_y_i);
    }
  }

}

miguelbiron commented 4 months ago

Hi, it looks like you are passing an invalid constrained sample:

  1. As far as I can tell, you have 3N constrained parameters, so 6 for the example you show in R, not 14.
  2. Even then, the simplex part is not properly constrained: you need to normalize the last N elements so that they sum to 1.
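
A minimal sketch of these two points in plain Julia, using only the `Random` standard library (no Pigeons or BridgeStan calls). The helper name `valid_constrained_point` is hypothetical, and the values drawn for `mu` and `sd_cluster` are placeholders chosen only to satisfy the constraints, not draws from the model's prior. The ordering follows the Stan `parameters` block: `mu`, then `sd_cluster`, then `omega`.

```julia
using Random

# Hypothetical helper: build one valid constrained vector [mu; sd_cluster; omega]
# for a mixture with N components, so the result has length 3N.
function valid_constrained_point(rng::AbstractRNG, N::Int)
    mu = randn(rng, N)            # any real values are valid for mu (placeholder draws)
    sd_cluster = randexp(rng, N)  # must be strictly positive (placeholder draws)
    e = randexp(rng, N)           # Exp(1) draws
    omega = e ./ sum(e)           # normalize so the simplex part sums to 1
    return vcat(mu, sd_cluster, omega)
end

x = valid_constrained_point(MersenneTwister(1), 2)
@assert length(x) == 6              # 3N with N = 2
@assert isapprox(sum(x[5:6]), 1.0)  # last N entries lie on the simplex
```

Normalizing independent Exp(1) draws gives a point distributed uniformly on the simplex, which matches the `dirichlet_lpdf(omega | rep_vector(1.0, N))` term in the Stan model.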

MarcoGallegos34 commented 4 months ago

Hi, thanks for your prompt reply. If I understand correctly, with N = 2, would code like this work? Also, there's no need to specify this sampler in the explorer argument of pt(), correct?

function Pigeons.sample_iid!(
        log_potential::StanLogPotential{M, S, D, StanUnidentifiableExample}, replica, shared) where {M, S, D}

    ## sample in constrained space ##

    # sampling from unit simplex for N = 2
    x_new = -log.(rand(replica.rng,2))
    norm_const = sum(x_new)
    x_final = x_new ./ norm_const

    # sampling for mu and sd_cluster parameters #
    constrained = rand(replica.rng,4)
    full_constrained = vcat(constrained,x_final)

    # transform to unconstrained space
    replica.state.unconstrained_parameters .= BridgeStan.param_unconstrain(log_potential.model, full_constrained)

end

Thanks
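
One thing to check in the snippet above: `rand(replica.rng, 4)` draws `mu` and `sd_cluster` uniformly on (0, 1), whereas the Stan model, when run with `isRef = 1` and `include_prior = 1`, puts `normal(0, sd_prior)` on each `mu[i]` and `gamma(0.1, 0.1)` on each `sd_cluster[i]`. A hedged sketch of drawing the constrained vector from those distributions instead follows. It assumes Distributions.jl is available (the thread does not mention it), the helper name `sample_reference_constrained` is hypothetical, and `sd_prior = 10.0` is a placeholder value. Note that Stan's `gamma` is parameterized by shape and rate, while `Gamma` in Distributions.jl takes shape and scale, so a rate of 0.1 becomes a scale of 1 / 0.1.

```julia
using Random
using Distributions  # assumption: Distributions.jl is installed

# Hypothetical helper: one iid draw from the reference in constrained space,
# ordered as in the Stan parameters block: [mu; sd_cluster; omega].
function sample_reference_constrained(rng::AbstractRNG, N::Int, sd_prior::Real)
    mu = rand(rng, Normal(0, sd_prior), N)          # normal(0, sd_prior) in Stan
    sd_cluster = rand(rng, Gamma(0.1, 1 / 0.1), N)  # gamma(shape 0.1, rate 0.1) in Stan
    omega = rand(rng, Dirichlet(N, 1.0))            # dirichlet(rep_vector(1.0, N)) in Stan
    return vcat(mu, sd_cluster, omega)
end

# Placeholder values: N = 2 components, sd_prior = 10.0
draw = sample_reference_constrained(MersenneTwister(1), 2, 10.0)
```

Inside `sample_iid!`, a vector built this way would take the place of `full_constrained` before the call to `BridgeStan.param_unconstrain`, with `replica.rng` as the random number generator.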