bgroenks96 / SimulationBasedInference.jl

A flexible toolkit for simulation based inference in Julia
MIT License

Unexpected behavior when providing initial_ens when solving inference problems #4

Closed: JunmiHu closed this issue 1 month ago

JunmiHu commented 1 month ago

Description: Passing initial_ens to the inference problem does not run the simulations on the supplied initial ensemble. Instead, the solver appears to draw a new ensemble from the prior and use that as the initial ensemble. This appears to happen with EnIS, ES-MDA, and EKS. I have not tested the other algorithms.

To reproduce: The example code below, based on the linearode example, reproduces the issue with EnIS(). Changing the inference algorithm to ES-MDA or EKS produces the same unexpected initial ensemble.

using SimulationBasedInference
using OrdinaryDiffEq
import Random

const rng = Random.MersenneTwister(1234);

"""Factory that returns the forward simulator for the linear ODE problem."""
function problem_factory(ode_func, t_data, solver)

    function problem_simulation(θ)
        println(θ)  # show the parameters each simulation actually receives
        prob = ODEProblem(ode_func, 1.0, (t_data[begin], t_data[end]), θ[1])
        sol = solve(prob, solver, saveat = t_data)
        return hcat(sol.u...)
    end

    return problem_simulation
end

# Right-hand side of the linear decay ODE du/dt = -α u
ode_func(u, p, t) = -p[1] * u

# Define true parameter
α_true = [0.2]

# Define time span and observation times
tspan = (0.0,10.0)
dt = 0.2
tsave = tspan[begin]:dt:tspan[end]
n_obs = length(tsave)

# Define observable and forward problem
observable = SimulatorObservable(:y, state -> state.u, (n_obs,))
ode_solver = Tsit5()
forward_prob = SimulatorForwardProblem(problem_factory(ode_func, tsave, ode_solver), α_true, observable)

# Generating synthetic data by running the forward solution and adding noise
forward_sol = solve(forward_prob);
true_obs = get_observable(forward_sol, :y)
noise_scale = 0.05
noisy_obs = true_obs .+ noise_scale*randn(rng, n_obs);

# Setting priors
model_prior = prior(α=Beta(2,2));
noise_scale_prior = prior(σ=Exponential(0.1));

# Assign a simple Gaussian likelihood for the observation.
lik = IsotropicGaussianLikelihood(observable, noisy_obs, noise_scale_prior);

# We now have all of the ingredients needed to set up and solve the inference problem.
# We will start with a simple ensemble importance sampling inference algorithm.
inference_prob = SimulatorInferenceProblem(forward_prob, model_prior, lik)
enis_sol = solve(inference_prob, EnIS(); ensemble_size=8, rng=rng);

# Extract the first ensemble of the first run
init_ens = get_transformed_ensemble(enis_sol, 1)

# Re-solve, passing that ensemble explicitly as the initial ensemble
enis_2 = solve(inference_prob, EnIS(); ensemble_size=8, rng=rng, initial_ens=init_ens);
true_ens = get_transformed_ensemble(enis_2, 1)

# Prints 0.0 only if the supplied ensemble was used as-is
println(maximum(abs.(true_ens .- init_ens)))

# The same behavior for the initial ensemble is observed for ES-MDA and EKS
#esmda_2 = solve(inference_prob, ESMDA(); ensemble_size=8, rng=rng, initial_ens=init_ens);
#eks_2 = solve(inference_prob, EKS(); ensemble_size=8, rng=rng, initial_ens=init_ens);

Running this code shows that init_ens and true_ens differ. The initial ensembles of the ES-MDA and EKS solvers show the same mismatch.

Expected behavior: I would expect the println statement in problem_simulation to print the members of the given initial ensemble. However, the printed parameters differ from init_ens. I believe this behavior is not intended.
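To make the expected behavior concrete, here is a stand-alone sketch of ensemble importance sampling that uses a caller-supplied ensemble verbatim and only samples when none is given. It is not the SimulationBasedInference.jl API: the functions simulate, loglik, and enis are hypothetical, the exact solution u(t) = exp(-α t) replaces the ODE solver, σ is fixed, a Uniform(0, 1) draw stands in for the Beta(2, 2) prior to avoid a Distributions.jl dependency, and supplied members are assumed to be prior draws so that the weights are just normalized likelihoods.

```julia
import Random

# Hypothetical sketch, not the package API.
# Exact solution of du/dt = -α u with u(0) = 1, evaluated at times ts.
simulate(α, ts) = exp.(-α .* ts)

# Isotropic Gaussian log-likelihood (up to an additive constant), fixed σ.
loglik(y, yhat, σ) = -sum(abs2, y .- yhat) / (2σ^2)

function enis(y, ts, σ; rng, ensemble_size, initial_ens=nothing)
    # Use the caller's ensemble as-is when given; otherwise draw from a
    # Uniform(0, 1) stand-in prior.
    ens = initial_ens === nothing ? rand(rng, ensemble_size) : copy(initial_ens)
    logw = [loglik(y, simulate(α, ts), σ) for α in ens]
    w = exp.(logw .- maximum(logw))  # subtract the max for numerical stability
    return ens, w ./ sum(w)          # members and normalized weights
end

# Usage: synthetic data with α = 0.2, then an explicit initial ensemble.
rng = Random.MersenneTwister(1234)
ts = 0.0:0.2:10.0
y = simulate(0.2, ts) .+ 0.05 .* randn(rng, length(ts))
init = [0.1, 0.2, 0.3, 0.4]
ens, w = enis(y, ts, 0.05; rng=rng, ensemble_size=4, initial_ens=init)
println(ens)  # identical to init
println(w)    # most weight on the member closest to the true α
```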

Notes on the issue: In the EnIS case, this seems to be caused by a premature call to prob.prob_func in SciMLBase/src/ensemble/basic_ensemble_solve.jl#L183 (new_prob = prob.prob_func(_prob, i, iter)), which updates the ensemble members according to the prob_func defined in SimulationBasedInference/src/ensembles/ensemble_solver.jl#L237-241:

    ensprob = EnsembleProblem(
        initial_prob;
        prob_func=(prob,i,repeat) -> prob_func(prob, param_map(ens[:,i])),
        output_func=(sol,i) -> output_func(sol, i, iter)
    )
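For context on what prob_func does, the sketch below uses only SciMLBase's public ensemble interface (reached through OrdinaryDiffEq), independent of SimulationBasedInference.jl. prob_func is called once per trajectory i and returns the problem to solve, here built with remake from column i of an ensemble matrix; output_func returns an (output, rerun) tuple. The decay ODE and the ens matrix are illustrative, and the package's extra param_map step is omitted.

```julia
using OrdinaryDiffEq  # re-exports EnsembleProblem, EnsembleSerial, remake

# Linear decay ODE du/dt = -p u with a scalar parameter p.
f(u, p, t) = -p * u
prob = ODEProblem(f, 1.0, (0.0, 1.0), 0.0)

# Each column holds one ensemble member (a single parameter here).
ens = [0.1 0.2 0.3]

ensprob = EnsembleProblem(prob;
    # Called once per trajectory i: swap in the parameters from column i.
    prob_func = (prob, i, repeat) -> remake(prob; p = ens[1, i]),
    # Record which parameter was simulated; `false` means "do not rerun".
    output_func = (sol, i) -> (sol.prob.p, false),
)

sim = solve(ensprob, Tsit5(), EnsembleSerial(); trajectories = size(ens, 2))
println(sim.u)  # parameters actually simulated, in trajectory order
```

If the parameters printed here differ from the columns of ens, the mismatch lies in how the ensemble matrix is built before the solve, rather than in prob_func itself.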

System info: OS: macOS 14.4.1; Julia version: 1.10.4

Let me know if you need any more information. Thank you for taking the time to develop the package!

bgroenks96 commented 1 month ago

Fixed by ecd7de58

The problem was that the initial ensemble was assumed to be provided in the unconstrained space, and so the mismatch was due to the transform not being applied. This is obviously counter-intuitive and has been corrected.
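The following sketch illustrates how a constrained/unconstrained mix-up produces this symptom. It assumes, purely for illustration, a logit/logistic bijection for a parameter supported on (0, 1) such as α ~ Beta(2, 2); which transform the package actually uses is not stated in this thread, and the actual fix is in ecd7de58. If a constrained ensemble is taken to be unconstrained, the forward map is applied to values that were already constrained. Mapping the user's ensemble to unconstrained space first makes the round trip return the original parameters.

```julia
# Hypothetical bijection for a (0, 1)-supported parameter: the sampler works
# with z in ℝ and maps back to the constrained space with logistic.
logistic(z) = 1 / (1 + exp(-z))
logit(x) = log(x / (1 - x))

init_ens = [0.2 0.4 0.8]   # user-supplied ensemble, constrained space

# Mix-up: init_ens is treated as already unconstrained, so the forward map
# is applied to constrained values and the simulated parameters change.
wrong = logistic.(init_ens)

# Consistent handling: map the user's ensemble to unconstrained space first.
z = logit.(init_ens)
right = logistic.(z)

println(wrong)  # differs from init_ens
println(right)  # matches init_ens up to floating-point rounding
```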

Thanks for reporting the issue.

bgroenks96 commented 1 month ago

Regression test can be found here.