An annoyance of mine has been that a "model" is defined in many disjoint places in the code. Once a `DataSet` is loaded, the posterior that goes along with it is in `lnP`. How to generate a simulation is defined in `resimulate`. How to Wiener filter the data for `f` is defined in `argmaxf_lnP`. You have to code all these by hand and they have to be consistent with each other, when in reality they all follow from a single definition of the "forward model" for the data.
Right now in the "ecosystem" we have CMBLensing's built-in `NoLensingDataset` (Gaussian CMB + beam + mask), its `BaseDataSet` (lensed CMB + beam + masking), we have three multifrequency DataSets in Pipeline B2, we have the `UltradeepDataSet`, and Federico & I are working on a patchy screening DataSet separately (patchy screening + lensing + beam + mask). Each of these has had to override each of those methods, and it's kind of a pain.
This PR makes it so you only define a "model" in a single place, and then all of those things work based on that. You do it using a "probabilistic programming language" (PPL) like PyMC3, Stan, Turing.jl, etc. (I tried using Julia's existing ones but they were too complex, so I ended up writing my own; it's only ~50 LOC). Here's what it looks like for `BaseDataSet`:
https://github.com/marius311/CMBLensing.jl/blob/aa6da76407ebb1122eae345d035755bd2fe99b30/src/dataset.jl#L55-L61
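For reference, based on the explanation that follows and the `@fwdmodel` syntax shown below for the B2 model, the `BaseDataSet` model presumably reads roughly along these lines (a reconstruction from the prose, not copied from the source):

```julia
@fwdmodel function (ds::BaseDataSet)(; f, ϕ, θ=(;), d)
    @unpack Cf, Cϕ, Cn, L, M, B = ds
    f ~ MvNormal(0, Cf(θ))                 # Gaussian CMB field
    ϕ ~ MvNormal(0, Cϕ(θ))                 # Gaussian lensing potential
    μ = M(θ) * (B(θ) * (L(ϕ) * f))         # mask ∘ beam ∘ lensing
    d ~ MvNormal(μ, Cn(θ))                 # data = mean + noise
end
```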
This says `f` is a Gaussian random field (`Distributions.MvNormal` is a multivariate normal with a given mean/covariance) with mean 0 and covariance `Cf`, and similarly for `ϕ`; then the mean of the data is `μ = M(θ) * (B(θ) * (L(ϕ) * f))` and the data is Gaussian with this mean and `Cn` noise covariance. Once this is defined, all of these work without having to code anything else by hand:
```julia
# generate a simulation from the forward model
simulate(ds)

# compute the log posterior given some values
logpdf(ds; f, ϕ, θ, ds.d)

# in the mixed parameterization
logpdf(mix(ds); f°, ϕ°, θ, ds.d)

# gradient of the log posterior
gradient((f, ϕ) -> logpdf(ds; f, ϕ, θ, ds.d), f, ϕ)

# maximize over f, i.e. the Wiener filter
argmaxf_logpdf(ds, (; ϕ, θ))
```
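To make the single-definition idea concrete without pulling in CMBLensing, here's a self-contained toy sketch of the underlying trick (all names here are made up for illustration, and scalars stand in for fields): the model body is written once, and two different interpreters of the `~` statements give you simulation and the joint log density.

```julia
using Random

# A trivial scalar "distribution" standing in for Distributions.MvNormal.
struct Gauss
    μ::Float64
    σ²::Float64
end
draw(rng, g::Gauss) = g.μ + sqrt(g.σ²) * randn(rng)
loggauss(g::Gauss, x) = -0.5 * ((x - g.μ)^2 / g.σ² + log(2π * g.σ²))

# The forward model is written ONCE, parameterized by an interpreter `op`
# that decides what each `x ~ dist` statement means.
function model(op; f=nothing, d=nothing)
    f = op(f, Gauss(0.0, 1.0))       # f ~ N(0, Cf)       (Cf = 1 here)
    d = op(d, Gauss(2.0 * f, 0.1))   # d ~ N(B * f, Cn)   (B = 2, Cn = 0.1)
    (; f, d)
end

# Interpreter 1: simulation — draw every `~` variable that wasn't supplied.
toy_simulate(rng = Random.default_rng()) =
    model((val, dist) -> val === nothing ? draw(rng, dist) : val)

# Interpreter 2: joint log density of supplied values.
function toy_logpdf(; f, d)
    total = 0.0
    model(; f, d) do val, dist
        total += loggauss(dist, val)
        val
    end
    total
end
```

A macro like `@fwdmodel` just automates the rewrite from `x ~ dist` statements into calls against whichever interpreter is active, so simulation, the posterior, and its gradients all stay consistent by construction.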
Similarly, once the other DataSets define a `@fwdmodel`, everything works. Here's an example for a B2 model which is multi-frequency and includes foregrounds:
```julia
@fwdmodel function (ds::GaussianDustSyncPolDataSet)(; f, ϕ, gdust, gsync, θ=(;), d)
    @unpack Cf, Cϕ, Cn, Cdust, Csync, Fcmb, Fdust, Fsync, L, M, B = ds
    f ~ MvNormal(0, Cf(θ))
    ϕ ~ MvNormal(0, Cϕ(θ))
    gdust ~ MvNormal(0, Cdust(θ))
    gsync ~ MvNormal(0, Csync(θ))
    μ = M(θ) * (B(θ) * (Fcmb * L(ϕ) * f + Fdust * gdust + Fsync * gsync))
    d ~ MvNormal(μ, Cn(θ))
end
```
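And once that's defined, the same generic calls from above apply to this dataset too; a usage sketch (assuming the signature above — `sim` and its fields are illustrative):

```julia
# draw a joint simulation of f, ϕ, gdust, gsync, and d
sim = simulate(ds)

# evaluate the log posterior at the simulated values
logpdf(ds; sim.f, sim.ϕ, sim.gdust, sim.gsync, d = sim.d)
```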
Currently there's no performance hit at all compared to the old code (this required one minor code-duplication hack, which will be gone once we switch from Zygote to Diffractor).
The changes to the old API are basically:
- `logpdf` instead of `lnP`
- `logpdf` takes keyword arguments for `f`, `ϕ`, `θ`, and `d`, instead of positional arguments like `lnP` did. Also, you have to pass in `d` explicitly, even though it's technically there in `ds`.
- The mixed posterior is now `logpdf(mix(ds); ...)` as opposed to the old `lnP(:mix, ..., ds)`
- `argmaxf_logpdf` instead of `argmaxf_lnP`.
Things like `MAP_joint`, `MAP_marg`, and `sample_joint` will use the new API under the hood with no other modifications needed. I'm keeping the old API around for now, though, with deprecation warnings. I'm curious to get your feedback on this new API, and anything you think might be worth changing (this is pretty fresh, so I very much welcome suggestions). @bthorne93 @kimmywu @EthanAnderes @Karthikprabhu22 @fbianchini
Hoping to merge this in the ~day timescale (things can of course be changed after too).