marius311 / CMBLensing.jl

The automatically differentiable and GPU-compatible toolkit for CMB analysis.
https://cosmicmar.com/CMBLensing.jl

Define posteriors via PPL #69

Closed marius311 closed 2 years ago

marius311 commented 3 years ago

An annoyance of mine has been that a "model" is defined in many disjoint places in the code. Once a DataSet is loaded, the posterior that goes along with it is in lnP. How to generate a simulation is defined in resimulate. How to Wiener filter the data for f is defined in argmaxf_lnP. You have to code all these by hand and they have to be consistent with each other, when in reality they all follow from a single definition of the "forward model" for the data.

Right now in the "ecosystem" we have CMBLensing's builtin NoLensingDataset (Gaussian CMB + beam + mask), its BaseDataSet (lensed CMB + beam + masking), we have three multifrequency DataSets in Pipeline B2, we have the UltradeepDataSet, and Federico & I are working on a patchy screening DataSet separately (patchy screening + lensing + beam + mask). Each of these has had to override each of those methods, and it's kind of a pain.

This PR makes it so you only define a "model" in a single place, and everything else works from that definition. You do it using a "probabilistic programming language" (PPL), like PyMC3, Stan, Turing.jl, etc. (I tried using Julia's existing ones but they were too complex, so I ended up writing my own; it's only ~50 LOC). Here's what it looks like for BaseDataSet:

https://github.com/marius311/CMBLensing.jl/blob/aa6da76407ebb1122eae345d035755bd2fe99b30/src/dataset.jl#L55-L61

This says f is a Gaussian random field (Distributions.MvNormal is a multivariate normal with given mean/covariance) with mean 0 and covariance Cf, and similarly for ϕ; then the mean of the data is μ = M(θ) * (B(θ) * (L(ϕ) * f)), and the data is Gaussian with this mean and noise covariance Cn. Once this is defined, all of these work without having to code anything else by hand:

# generate a simulation from the forward model
simulate(ds) 

# compute the log posterior given some values
logpdf(ds; f, ϕ, θ, ds.d) 

# in the mixed parameterization
logpdf(mix(ds); f°, ϕ°, θ, ds.d) 

# gradient of the log posterior
gradient((f, ϕ) -> logpdf(ds; f, ϕ, θ, ds.d), f, ϕ) 

# maximize over f, ie the Wiener filter
argmaxf_logpdf(ds, (;ϕ, θ)) 
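
For reference, the forward model described above implies the following joint log-density, written here as a sketch. The notation is an assumption for illustration: daggers denote adjoints, and A stands for the combined operator M(θ) B(θ) L(ϕ). Whether the package keeps or drops the normalization terms in a given call is not stated here; they only matter when the covariances depend on θ.

```latex
\begin{aligned}
\mu &= M(\theta)\, B(\theta)\, L(\phi)\, f \\
\log P(f, \phi, d \mid \theta)
  &= -\tfrac{1}{2} (d - \mu)^\dagger C_n(\theta)^{-1} (d - \mu)
     - \tfrac{1}{2} f^\dagger C_f(\theta)^{-1} f
     - \tfrac{1}{2} \phi^\dagger C_\phi(\theta)^{-1} \phi \\
  &\quad - \tfrac{1}{2} \log\det C_n(\theta)
     - \tfrac{1}{2} \log\det C_f(\theta)
     - \tfrac{1}{2} \log\det C_\phi(\theta)
     + \text{const}
\end{aligned}
```

Because μ is linear in f at fixed ϕ and θ, maximizing over f has a closed form, which is the Wiener filter:

```latex
\hat f = \left( C_f^{-1} + A^\dagger C_n^{-1} A \right)^{-1} A^\dagger C_n^{-1} d,
\qquad A \equiv M(\theta)\, B(\theta)\, L(\phi)
```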

Similarly, once the other DataSets define a @fwdmodel everything works. Here's an example for a B2 model which is multi-frequency and includes foregrounds:

@fwdmodel function (ds::GaussianDustSyncPolDataSet)(; f, ϕ, gdust, gsync, θ=(;), d)
    @unpack Cf, Cϕ, Cn, Cdust, Csync, Fcmb, Fdust, Fsync, L, M, B = ds
    f ~ MvNormal(0, Cf(θ))
    ϕ ~ MvNormal(0, Cϕ(θ))
    gdust ~ MvNormal(0, Cdust(θ))
    gsync ~ MvNormal(0, Csync(θ))
    μ = M(θ) * (B(θ) * (Fcmb * L(ϕ) * f + Fdust * gdust + Fsync * gsync))
    d ~ MvNormal(μ, Cn(θ))
end
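
To illustrate why a single model definition is enough, here is a minimal toy sketch. It is not the implementation in this PR, and it does not use the @fwdmodel macro. It assumes a hypothetical `IsoNormal` distribution (diagonal covariance σ²·I) and a hypothetical `toy_model` in which every `x ~ dist` site is written as a call to a `tilde` handler. Swapping the handler turns the same model into either a simulator or a log-density evaluator. All names here (`IsoNormal`, `draw`, `toy_model`, `joint_logpdf`) are made up for the example.

```julia
using LinearAlgebra, Random

# Hypothetical stand-in for a Gaussian with mean μ and covariance σ²·I.
struct IsoNormal{T<:AbstractVector}
    μ::T
    σ::Float64
end

# Draw one sample from the distribution.
draw(rng::AbstractRNG, dist::IsoNormal) = dist.μ .+ dist.σ .* randn(rng, length(dist.μ))

# Log density of x under the distribution, including normalization.
function logpdf(dist::IsoNormal, x)
    n = length(dist.μ)
    return -sum(abs2, x .- dist.μ) / (2 * dist.σ^2) - n * log(dist.σ) - n / 2 * log(2π)
end

# The single model definition. `tilde(name, dist)` plays the role of `name ~ dist`.
function toy_model(tilde, A)
    f = tilde(:f, IsoNormal(zeros(size(A, 2)), 1.0))  # prior on the field
    μ = A * f                                         # linear forward operator
    d = tilde(:d, IsoNormal(μ, 0.1))                  # data = mean + noise
    return (; f, d)
end

# Interpretation 1: sample every variable from its distribution.
simulate(rng::AbstractRNG, A) = toy_model((name, dist) -> draw(rng, dist), A)

# Interpretation 2: plug in given values and accumulate the log density.
function joint_logpdf(A, vars::NamedTuple)
    total = Ref(0.0)
    toy_model(A) do name, dist
        x = vars[name]
        total[] += logpdf(dist, x)
        return x
    end
    return total[]
end

# Usage: simulate from the model, then evaluate the density at the simulated values.
rng = MersenneTwister(1)
A = [1.0 0.0; 0.0 2.0; 1.0 1.0]
sim = simulate(rng, A)
lp = joint_logpdf(A, sim)
```

Both `simulate` and `joint_logpdf` are derived from `toy_model`, so they cannot drift out of sync with each other, which is the property the PR is after for simulate, logpdf, and argmaxf_logpdf.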

Currently there's no performance hit at all compared to the old code (this required one minor code-duplication hack, which will go away once we switch from Zygote to Diffractor).

The changes to the old API are basically:

Things like MAP_joint, MAP_marg, and sample_joint will use the new API under the hood with no other modifications needed. I'm keeping the old API around for now though with deprecation warnings. I'm curious to get your feedback on this new API, and anything you think might be worth changing (this is pretty fresh so I very much welcome suggestions). @bthorne93 @kimmywu @EthanAnderes @Karthikprabhu22 @fbianchini

Hoping to merge this on a ~day timescale (things can of course be changed after, too).