Support pyhf #318

Open cranmer opened 3 years ago

cranmer commented 3 years ago

Hello, I believe the previous version of BAT had an interface for working with RooFit likelihoods (that can be stored in a RooFit workspace in a ROOT file). I am curious if there are plans to provide such a module, perhaps in a separate repository. I know there is a generic API for likelihoods.

If I am correct that this was supported in the previous version, it might be worth mentioning the current status of this type of interface in BAT.jl.


oschulz commented 3 years ago

That's why we have the "BAT.jl now offer a different set of functionality [...] C++ predecessor." in the docs. :-)

That being said, we do actually want to bring the RooFit interface back at some point, in principle. It would be an add-on package, due to its additional dependencies. It's not trivial at the moment though, since Cxx.jl doesn't support current Julia versions (hopefully that will change, there are people working on it). BAT.jl does have experimental support for external likelihoods that run in a separate process and communicate with BAT.jl via pipes; that works without C++ calls from Julia.

One of the main reasons why we haven't prioritized this a lot is that we haven't had any active use cases with RooFit. If you do plan to use BAT.jl with RooFit though, we'll be happy to help!

cranmer commented 3 years ago

Ok, thanks! [It would be nice to write an example interface for pyhf as well: ]

@oschulz Can I ask an unrelated question (sorry to abuse the GH issue). Can one read/write a BAT model from a file? I don't mean the posterior samples, but the model likelihood and prior itself? From what I see in the documentation, I don't think so.

cranmer commented 3 years ago

Ok, talking with Philip, it sounds like maybe you can write the model to a file with Julia's built-in serialization.

Looking at I see

For lossless storage of arbitrary Julia objects, the only other complete solution appears to be Julia's serializer, which can be accessed via the serialize and deserialize commands. However, because the serializer is also used for inter-process communication, long-term backwards compatibility is currently uncertain.

oschulz commented 3 years ago

Indeed, JLD2 can serialize most Julia structures natively, but is not a long-term data preservation format. BAT itself is intended to work with user-provided likelihoods that may contain arbitrary code and is not limited to a DSL, so it doesn't provide an inherent long-term stable serialization format. But if there's interest, we could of course support loading specific formats for serialized likelihoods, for example via add-on packages to BAT.

oschulz commented 3 years ago

I actually discussed just that with @lukasheinrich today. :-) Having pyhf import (and possibly re-export from a pyhf DSL in Julia) capability would definitely be great. It would also require some manpower though (e.g. coupled to a thesis or so) since pyhf is not a tiny spec.

It wouldn't necessarily have to be BAT-specific/dependent, one could think about a standalone PyHF Julia package that implements (e.g.) the lightweight DensityInterface API (BAT will support DensityInterface very soon).

oschulz commented 3 years ago

CC @mmikhasenko

lukasheinrich commented 3 years ago

this is great..

one of the easiest pyhf models is the following joint model over two Poissons (with s, b, d being hyper-parameters). What would this look like in BAT?

p(n, a | mu,gamma) = Pois( n | mu*s + gamma*b) Pois( a | gamma*d )

lukasheinrich commented 3 years ago

also as a cross-ref.. we used to have a RooFit/python bridge here: but it could certainly be improved.

oschulz commented 3 years ago

In Julia/BAT, it would look like this (for example):

using Distributions, UnPack, BAT, ValueShapes

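likelihood = let n = 7, a = 4, s = 1, b = 2, d = 3
    function (v)
        @unpack μ, γ = v
        # Joint log-likelihood of the main count n and the auxiliary count a
        ll = logpdf(Poisson(μ*s + γ*b), n) + logpdf(Poisson(γ*d), a)
        return LogDVal(ll)  # tag the value as a log-density
    end
end

v = (μ = 3, γ = 4)
likelihood(v)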


One could also write an explicit forward model function that returns a distribution over NamedTuples (n = ..., a = ...) and simply use the logpdf of those distributions in the likelihood:

forward_model = let s = 1, b = 2, d = 3
    function (v)
        @unpack μ, γ = v
        # Distribution over NamedTuples (n = ..., a = ...)
        return NamedTupleDist(
            n = Poisson(μ*s + γ*b),
            a = Poisson(γ*d)
        )
    end
end

data = (n = 7, a = 4)

likelihood = let forward_model = forward_model, data = data
    v -> LogDVal(logpdf(forward_model(v), data))
end


(Note: The let-statements are just there to ensure the current values are captured in a closure, since access to non-const global variables is slow in Julia because the compiler can't infer their type at code-generation time).

Of course one would still need to define a prior for (μ, γ) (BAT is a Bayesian toolkit, after all ;-) ).

lukasheinrich commented 3 years ago

ok step 0 works :)

Python 3.7.2 (default, Jul  3 2020, 03:38:00) 
[Clang 10.0.0 (clang-1000.10.44.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import pyhf
>>> m = pyhf.simplemodels.uncorrelated_background([5],[50],[5])
>>> m.logpdf([1.0,1.0],[50.,100])

julia> using Distributions, UnPack, BAT, ValueShapes

julia> likelihood = let n = 50, a = 100, s = 5, b = 50, d = 100
           function (v)
               @unpack μ, γ = v
               ll = logpdf(Poisson(μ*s + γ*b), n) + logpdf(Poisson(γ*d), a)
               return LogDVal(ll)
           end
       end
#1 (generic function with 1 method)

julia> v = (μ = 1, γ = 1)
(μ = 1, γ = 1)

julia> likelihood(v)

is there an easy way to get gradients of the likelihood, @oschulz (both wrt μ, γ and wrt n, a, s, b, d)?

niklasschmitz commented 2 years ago

Disclaimer: I'm not familiar with BAT.jl internals, so please do correct me.

For gradients, an easy way seems to be

julia> using Distributions, UnPack, BAT, ValueShapes

julia> likelihood = let n = 50, a = 100, s = 5, b = 50, d = 100
           function (v)
               @unpack μ, γ = v
               ll = logpdf(Poisson(μ*s + γ*b), n) + logpdf(Poisson(γ*d), a)
               return LogDVal(ll)
           end
       end
#1 (generic function with 1 method)

julia> v = (μ = 1, γ = 1)
(μ = 1, γ = 1)

julia> likelihood(v)

and then to use the reverse-mode AD package Zygote.jl

julia> using Zygote

julia> gradient(v -> likelihood(v).logval, v)
((μ = -0.4545454545454547, γ = -4.545454545454547),)

To differentiate w.r.t. parameters n,a,s,b,d one could similarly introduce another lambda function. Alternatively, another fun Zygote feature is differentiation w.r.t. a function object directly: This returns NamedTuple gradients w.r.t. closed-over variables:

julia> gradient(likelihood -> likelihood(v).logval, likelihood)
((n = nothing, a = nothing, s = -0.09090909090909094, b = -0.09090909090909094, d = 0.0),)

or all together in one sweep

julia> gradient((l, v) -> l(v).logval, likelihood, v)
((n = nothing, a = nothing, s = -0.09090909090909094, b = -0.09090909090909094, d = 0.0), (μ = -0.4545454545454547, γ = -4.545454545454547))
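As a cross-check on the Zygote numbers above, the derivatives of this two-Poisson log-likelihood can also be written down by hand, since d/dλ log Pois(k | λ) = k/λ - 1. The following is a minimal standard-library sketch and not part of BAT or pyhf; the function name `log_likelihood_gradients` is hypothetical.

```python
def log_likelihood_gradients(n, a, s, b, d, mu, gamma):
    """Analytic partial derivatives of
    log Pois(n | mu*s + gamma*b) + log Pois(a | gamma*d)."""
    lam_n = mu * s + gamma * b  # expected count of the main measurement
    lam_a = gamma * d           # expected count of the auxiliary measurement
    # d/dlam log Pois(k | lam) = k/lam - 1, then apply the chain rule
    r_n = n / lam_n - 1.0
    r_a = a / lam_a - 1.0
    return {
        "mu": s * r_n,
        "gamma": b * r_n + d * r_a,
        "s": mu * r_n,
        "b": gamma * r_n,
        "d": gamma * r_a,
    }


grads = log_likelihood_gradients(n=50, a=100, s=5, b=50, d=100, mu=1.0, gamma=1.0)
print(grads)
```

With the values used in this thread, the entries for mu, gamma, s, b and d agree with the gradients Zygote reports above. The counts n and a are integers here, which is consistent with Zygote returning `nothing` for them.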

lukasheinrich commented 2 years ago

awesome.. works nicely.. for reference our jax stuff in pyhf

import pyhf
import jax
m = pyhf.simplemodels.uncorrelated_background([5],[50],[5])


DeviceArray([[-0.45454545, -4.54545455]], dtype=float64)

oschulz commented 2 years ago

Thanks for the nice write-up, @niklasschmitz!

In addition to Zygote, one can of course also use ForwardDiff (faster with very few parameters due to lower overhead), Enzyme (though it crashes here; it doesn't like something in the Poisson logpdf, and Enzyme is still quite experimental), etc. ForwardDiff can only handle a single array as input, but BAT handles that internally via the ValueShapes shaped-to-flat-array transformation based on the prior.

Note: I'm in the process of adding an alternative to the LogDVal "this is a log"-tagging mechanism, should be on the master branch soon.

Moelf commented 2 years ago

pyhf is free via PyCall, for example, this hello world from the main doc page:

using PyCall
pyhf = pyimport("pyhf")

model = pyhf.simplemodels.uncorrelated_background(
    signal=PyVector([12.0, 11.0]), bkg=PyVector([50.0, 52.0]), bkg_uncertainty=PyVector([3.0, 7.0])
)

data = [[51, 48]; model.config.auxdata]
test_mu = 1.0

CLs_obs, CLs_exp = pyhf.infer.hypotest(
    test_mu, data, model, test_stat="qtilde", return_expected=true
)

print("Observed: $CLs_obs, Expected: $CLs_exp")

# Observed: fill(0.052514974238085446), Expected: fill(0.06445320535890237)

PyVector is not needed after:

btw, I think it's useful to have some set of "common examples" to show pyhf <-> BAT, even though not everything can be translated as easily.

oschulz commented 2 years ago

pyhf is free via PyCall, for example

Oh, sure! Might still be very useful to have a Julia implementation longer-term, though, for deeper integration, to use Julia autodiff, and so on.

oschulz commented 2 years ago

But you're right, we can use pyhf right now via PyCall; it would be good to put an example for pyhf + BAT.jl together. Have to think about how to handle priors, though.

lukasheinrich commented 2 years ago

the current model is that pyhf comes with "auxiliary measurements" in the box so

p(data | μ,ν) p(aux| ν), where the p(aux| ν) is a surrogate for the measurements performed by e.g. detector groups

so the priors that are needed are p(µ) and p(ν) .. the former would be your prior belief on BSM parameters while the latter would be an "ur-prior" on nuisance parameters (though that'd likely be non-informative, as most information about the detector performance comes from those aux. measurements)

we also provide APIs to only have p(data | μ,ν) if you'd like to provide the prior p(ν) = p(ν|aux) yourself
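To make the split between likelihood, auxiliary term and priors concrete, here is a hedged standard-library sketch for the two-Poisson model from earlier in the thread, with γ playing the role of ν. The priors are illustrative assumptions only (a flat p(μ) on [0, 4] and a wide normal "ur-prior" on γ), and all function names are hypothetical rather than pyhf or BAT API.

```python
import math


def poisson_logpmf(k, lam):
    """Log of Pois(k | lam) for integer k >= 0 and lam > 0."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def log_prior(mu, gamma, mu_max=4.0, gamma_sigma=2.0):
    """Assumed priors: mu ~ Uniform(0, mu_max), gamma ~ Normal(1, gamma_sigma)."""
    if not 0.0 <= mu <= mu_max:
        return -math.inf
    log_p_mu = -math.log(mu_max)
    z = (gamma - 1.0) / gamma_sigma
    log_p_gamma = -0.5 * z * z - math.log(gamma_sigma * math.sqrt(2.0 * math.pi))
    return log_p_mu + log_p_gamma


def log_posterior(mu, gamma, n, a, s, b, d):
    """Unnormalized log posterior: log p(n | mu, gamma) + log p(a | gamma) + log prior."""
    lp = log_prior(mu, gamma)
    if lp == -math.inf or gamma <= 0.0:
        return -math.inf
    main = poisson_logpmf(n, mu * s + gamma * b)  # p(data | mu, nu)
    aux = poisson_logpmf(a, gamma * d)            # p(aux | nu)
    return main + aux + lp


print(log_posterior(1.0, 1.0, n=50, a=100, s=5, b=50, d=100))
```

Dropping the `aux` term and replacing the wide normal with an informative p(ν | aux) would correspond to the second option described above.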

oschulz commented 2 years ago

Does pyhf export an API for its autodiff-pullback?

lukasheinrich commented 2 years ago

you can get the pullback via jax, for example as follows, but for now this is backend-specific:

import pyhf
import jax
pyhf.set_backend("jax")  # jax.vjp needs pyhf to compute with jax arrays
m = pyhf.simplemodels.uncorrelated_background([5],[50],[7])
lhood = lambda p: m.logpdf(p,[50]+m.config.auxdata)
v, pullback = jax.vjp(lhood,jax.numpy.array([1.,1.]))

if you want the lhood gradient, we have more backend-independent apis

oschulz commented 2 years ago

Neat! We can wrap that in a ChainRulesCore.rrule.

lukasheinrich commented 2 years ago

@niklasschmitz has some thoughts on that

oschulz commented 2 years ago

It'll basically just be something like this:

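using DensityInterface, ChainRulesCore, PyCall
const jax = pyimport("jax")

# Wraps a Python log-likelihood callable (e.g. `lhood` from the jax example)
struct PyHfLikelihood{LT}
    log_likelihood::LT
end

DensityInterface.logdensityof(d::PyHfLikelihood, x::AbstractVector{<:Real}) = d.log_likelihood(x)

function ChainRulesCore.rrule(::typeof(DensityInterface.logdensityof), d::PyHfLikelihood, x::AbstractVector{<:Real})
    y, pypullback = jax.vjp(d.log_likelihood, jax.numpy.array(x))
    function pyhf_pullback(dy)
        # jax returns one cotangent per input, so take the first (and only) one
        dx = @thunk first(pypullback(unthunk(dy)))
        return NoTangent(), NoTangent(), dx
    end
    return y, pyhf_pullback
end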

Something along these lines should already be a BAT and Zygote-compatible likelihood density.
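For readers less familiar with the value-plus-pullback pattern that both `jax.vjp` and `rrule` follow, here is a minimal hand-written sketch for the two-Poisson model from earlier in the thread. It uses only the Python standard library; `value_and_pullback` is a hypothetical name, and the default counts are the ones used in the gradient examples above.

```python
import math


def poisson_logpmf(k, lam):
    """Log of Pois(k | lam) for integer k >= 0 and lam > 0."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def value_and_pullback(params, n=50, a=100, s=5, b=50, d=100):
    """Return the log-likelihood at params = (mu, gamma) and a pullback closure."""
    mu, gamma = params
    lam_n = mu * s + gamma * b
    lam_a = gamma * d
    value = poisson_logpmf(n, lam_n) + poisson_logpmf(a, lam_a)

    def pullback(dy):
        # Vector-Jacobian product: the output is a scalar, so the pullback
        # scales the gradient by the output cotangent dy.
        r_n = n / lam_n - 1.0
        r_a = a / lam_a - 1.0
        return (dy * s * r_n, dy * (b * r_n + d * r_a))

    return value, pullback


value, pullback = value_and_pullback((1.0, 1.0))
print(pullback(1.0))
```

Calling `pullback(1.0)` gives the plain gradient; an `rrule` wrapper does the same thing, except that the forward pass and the pullback are delegated to the Python side.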

Moelf commented 2 years ago

re-using pyhf's likelihood seems to be free:

using PyCall, BAT
pyhf = pyimport("pyhf")
# spec is a JSON/Dict per pyhf standard
pdf_pyhf = pyhf.Model(spec)
data_pyhf = [v_data; pdf_pyhf.config.auxdata]

function likelihood_pyhf(v)
    (;μ, θ) = v
    LogDVal(only(pdf_pyhf.logpdf([μ, θ], data_pyhf)))
end

then for BAT

prior = BAT.NamedTupleDist(
  μ = Uniform(0, 4),
  θ = Normal()
)

posterior = PosteriorDensity(likelihood_pyhf, prior)
best_fit = bat_findmode(posterior).result

just works. Of course, in this case it's probably Optim.jl handling the fit without gradients.

lukasheinrich commented 2 years ago

nice! one thing that might be nice to try is to use @philippeller's batty from python

also note: the pyhf pdf comes with data-driven constraints that in a native Bayesian approach would perhaps be given as priors. I.e. with this more detailed PDF, the priors BAT should receive are "ur-priors", which are likely less informative; e.g. θ might rather be a wide normal (with sigma > 1), as the auxdata already provides some data to constrain it

oschulz commented 2 years ago

@Moelf provided a nice initial example in #345.

Moelf commented 2 years ago

Still need to figure out how to define the correct rrule, right now it would crash if you try to sample the posterior

oschulz commented 2 years ago

And we should make life easier on the user by looking into the pyhf model to get the order of variables, so things won't go badly if they are in a different order in the prior.
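A hedged sketch of what such a reordering step could look like on the Python side: named parameter values (for example one draw from the prior) are flattened into the order the model expects. The helper `flatten_in_model_order` and the parameter names are hypothetical. I believe pyhf reports the order as `model.config.par_order`, but treat that as an assumption; the sketch also assumes each parameter is a single scalar.

```python
def flatten_in_model_order(named_params, par_order):
    """Turn a dict of named scalar parameters into a flat list in the model's order."""
    missing = [name for name in par_order if name not in named_params]
    extra = [name for name in named_params if name not in par_order]
    if missing or extra:
        raise KeyError(f"missing parameters: {missing}, unexpected parameters: {extra}")
    return [float(named_params[name]) for name in par_order]


# Hypothetical order, standing in for whatever the pyhf model reports.
par_order = ["mu", "uncorr_bkguncrt"]
# The prior lists the parameters in a different order than the model.
prior_draw = {"uncorr_bkguncrt": 1.1, "mu": 0.5}
print(flatten_in_model_order(prior_draw, par_order))  # → [0.5, 1.1]
```

Raising on missing or unexpected names, rather than silently relying on position, is the point of the exercise: a prior defined in a different order then cannot feed values into the wrong parameter slots.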