SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

How to estimate intractable likelihood functions with normalizing flows. #791

Open itsdfish opened 1 year ago

itsdfish commented 1 year ago

I am opening an issue based on this discourse thread, in which I ask about estimating intractable likelihood functions with normalizing flows. As described in this paper, a normalizing flow can approximate the likelihood function of a model by learning the relationship between the model's parameters and its output. A working example would be of broad interest, given that many scientific fields work with complex models for which the likelihood function is unknown or intractable.

The paper is associated with a Python package called sbi. Here is a simple working example based on a LogNormal distribution. I pasted my attempt at replicating it in Julia below. Note that the package generalizes to processes that emit multiple distributions, but I have used a single distribution for simplicity.

The architectural details in the article were a little sparse; the relevant passage is:

For the neural spline flow architecture (Durkan et al., 2019), we transformed the reaction time data to the log-domain, used a standard normal base distribution, 2 spline transforms with 5 bins each and conditioning networks with 3 hidden layers and 10 hidden units each, and rectified linear unit activation functions. The neural network training was performed using the sbi package with the following settings: learning rate 0.0005; training batch size 100; 10% of training data as validation data, stop training after 20 epochs without validation loss improvement.
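In case it is useful, here is a rough sketch of how those training settings could be wired up in Julia with a recent Flux release (explicit-parameter API): Adam with learning rate 0.0005, training batch size 100, 10% of the data held out for validation, and early stopping after 20 epochs without validation improvement. This is not the paper's spline-flow architecture; the `model`, `nll`, and `train_flow!` names below are placeholders I made up to illustrate only the training loop.

using Flux

# placeholder network and loss standing in for the spline-flow density estimator;
# a real flow would return -logpdf of the sample given the parameters
model = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 10, relu), Dense(10, 1)) |> f32
nll(m, batch) = sum(abs2, m(batch)) / size(batch, 2)

# placeholder (μ, σ, sample) columns; 10% of training data as validation data
data = randn(Float32, 3, 10_000)
n_val = round(Int, 0.1 * size(data, 2))
val, train_set = data[:, 1:n_val], data[:, (n_val + 1):end]

opt_state = Flux.setup(Adam(5f-4), model)                              # learning rate 0.0005
loader = Flux.DataLoader(train_set; batchsize = 100, shuffle = true)   # training batch size 100

function train_flow!(model, opt_state, loader, val; max_epochs = 1_000, patience = 20)
    best, stale = Inf32, 0
    for epoch in 1:max_epochs
        for batch in loader
            grad = Flux.gradient(m -> nll(m, batch), model)[1]
            Flux.update!(opt_state, model, grad)
        end
        vloss = nll(model, val)
        vloss < best ? (best = vloss; stale = 0) : (stale += 1)
        stale ≥ patience && return epoch   # stop after 20 epochs without validation improvement
    end
    return max_epochs
end

train_flow!(model, opt_state, loader, val)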

It appears that the architecture is heavily influenced by Sequential Neural Likelihood: Fast Likelihood-free Inference with Autoregressive Flows. What appears to be the core Python code can be found here.

Also, for additional background, Hossein has started a related package, but it is experimental and has no documentation.

Thank you for looking into this. I don't know much about neural networks and Flux, but let me know if I can be helpful at all.

WIP Code


###########################################################################################################
#                                           load packages
###########################################################################################################
cd(@__DIR__)
using Pkg 
Pkg.activate("")
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux
using OptimizationOptimJL, Distributions
using Random
Random.seed!(3411)
###########################################################################################################
#                                           generate training data
###########################################################################################################
n_parms = 10_000
# training parameters
train_parms = map(x -> Float32.(rand(Gamma(1.0, .5), 2)), 1:n_parms)
# training samples
samples = map(p -> Float32(rand(LogNormal(p...))), train_parms)
# stack into a 3 × n_parms Float32 matrix: rows 1–2 are (μ, σ), row 3 is the sample
training_data = [hcat(train_parms...); samples']
###########################################################################################################
#                                           setup network
###########################################################################################################
nn = Flux.Chain(
    # inputs are the parameters μ and σ plus a distribution sample;
    # the FFJORD dynamics network must map the 3-dimensional state back to 3 dimensions
    Flux.Dense(3, 10, tanh),
    Flux.Dense(10, 10, tanh),
    Flux.Dense(10, 3, tanh),
) |> f32
tspan = (0.0f0, 50.0f0)

ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())

function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(training_data, θ)
    -mean(logpx)
end

function cb(p, l)::Bool
    # the current loss value is passed in, so there is no need to recompute it
    @info "Training" loss = l
    false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)

res1 = Optimization.solve(optprob,
                          ADAM(0.1),
                          maxiters = 100, callback=cb)

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
                          Optim.LBFGS(),
                          allow_f_increases=false, callback=cb)
###########################################################################################################
#                                           evaluate and plot
###########################################################################################################
using Plots

test_parms = Float32[1.5, .5]
xs = [0.05f0:0.05f0:20;]
true_density = pdf.(LogNormal(test_parms...), xs)
# is there a better way to get the estimated density?
# the flow expects the full 3-dimensional input (μ, σ, sample)
est_density = map(x -> exp(first(ffjord_mdl([test_parms; x], res2.u, monte_carlo = false)[1])), xs)
# plot the true and estimated densities
plot(xs, true_density)
plot!(xs, est_density)
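
Regarding the inline question above, one possible alternative is a single batched call, assuming the Flux-based FFJORD functor accepts a 3 × N matrix with one evaluation point per column (as in the DiffEqFlux FFJORD tutorial). The `test_input` and `est_density_batched` names are just placeholders for this sketch:

# batched sketch: one column per evaluation point, rows are (μ, σ, sample)
test_input = vcat(repeat(test_parms, 1, length(xs)), Float32.(xs)')
logpx, _, _ = ffjord_mdl(test_input, res2.u; monte_carlo = false)
est_density_batched = exp.(vec(logpx))
plot!(xs, est_density_batched)
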
itsdfish commented 1 year ago

I think BayesFlow is another neural network approach that would be good to have.