scikit-hep / pyhf

pure-Python HistFactory implementation with tensors and autodiff
https://pyhf.readthedocs.io/
Apache License 2.0
283 stars 83 forks source link

for later reference: how to compare to ROOT NLL and MINUIT Error #764

Open lukasheinrich opened 4 years ago

lukasheinrich commented 4 years ago

This is a short script that numerically compares our NLL to ROOT's NLL. It will be useful for making apples-to-apples comparisons of fit results.

import pyhf
import json

# Load the workspace specification exactly once; the context manager closes
# the JSON file handle instead of leaving it to the garbage collector.
with open('./shapesys.json') as spec_file:
    w = pyhf.Workspace(json.load(spec_file))

# Build the model for the 'GaussExample' measurement and the matching data.
m = w.model(measurement_name='GaussExample')
d = w.data(m)

def root_fcn(d, m, pars):
    """Assemble a negative log-likelihood term by term for comparison with ROOT.

    Args:
        d: data accepted by ``pyhf.tensorlib.astensor`` (main + auxiliary data).
        m: model exposing ``fullpdf_tv``, ``constraint_model``, ``main_model``
            and ``logpdf``.
        pars: parameter tensor handed to the ``make_pdf`` calls.

    Returns:
        The negated sum of the extended, shape, normal-constraint and
        Poisson-constraint components.

    Side effects:
        Prints the assembled log-likelihood next to ``m.logpdf(pars, d)[0]``
        so the two numbers can be compared by eye.
    """
    tl = pyhf.tensorlib
    constraints = m.constraint_model
    bin_width = 1.0

    # Separate the main-measurement part of the data from the auxiliary part,
    # then split the auxiliary part into its normal and Poisson pieces.
    main, aux = m.fullpdf_tv.split(tl.astensor(d))
    norm_aux, pois_aux = constraints.constraints_tv.split(aux)

    # Poisson constraint; the added sum depends only on the auxiliary data,
    # not on ``pars`` (flagged as "is constant" by the original author).
    const_pois = tl.sum(tl.log(pois_aux))
    pois_logp = constraints.constraints_poisson.make_pdf(pars).log_prob(pois_aux)
    pois_comp = pois_logp + const_pois

    # Normal (Gaussian) constraint.
    normal_comp = constraints.constraints_gaussian.make_pdf(pars).log_prob(norm_aux)

    # Shape term of the main model: observed counts weighted by the log of
    # the normalised expected rates.
    main_rates = m.main_model.make_pdf(pars)._pdf.rate
    total_exp = tl.sum(main_rates)
    total_obs = tl.sum(main)
    main_const = -tl.sum(main * tl.log(bin_width))  # independent of pars
    shape_logp = tl.sum(main * tl.log(main_rates / total_exp))
    main_comp = shape_logp + main_const

    # Extended term of the main model.
    extended_comp = total_obs * tl.log(total_exp) - total_exp

    logl = extended_comp + main_comp + normal_comp + pois_comp

    print(logl, m.logpdf(pars, d)[0])

    return -logl

# Evaluate the hand-assembled NLL at the model's suggested initial parameters.
pars = pyhf.tensorlib.astensor(m.config.suggested_init())
# Optional: override the 'bkg1Shape' parameters before evaluating
# (the same values appear in the commented-out setVal calls of the ROOT script).
# pars[m.config.par_slice('bkg1Shape')] = [0.3,1.2]
nll = root_fcn(d,m,pars)
print('nll is',nll)

compared to

import ROOT
import sys

# Open the ROOT file given as the first command-line argument and fetch the
# object stored under the key 'combined', its pdf 'simPdf' and dataset 'obsData'.
f = ROOT.TFile.Open(sys.argv[1])
w = f.Get('combined')
p  = w.pdf('simPdf')
d  = w.data('obsData')

# Optional: set the two gamma parameters to non-default values before
# evaluating (same values as the commented-out override in the pyhf script).
#w.var('gamma_bkg1Shape_bin_0').setVal(0.3)
#w.var('gamma_bkg1Shape_bin_1').setVal(1.2)
w.Print()

# Build the NLL of the pdf on the dataset and print its current value.
nll = p.createNLL(d)

print(nll.getVal())
lukasheinrich commented 4 years ago

this is how you call the ROOT NLL using a numpy array

import ROOT
import sys
import array
import numpy as np
import time

# Open the ROOT file given as the first command-line argument and fetch the
# object stored under 'combined', its pdf 'simPdf' and dataset 'obsData'.
f = ROOT.TFile.Open(sys.argv[1])
w = f.Get('combined')
p  = w.pdf('simPdf')
d  = w.data('obsData')

nll = p.createNLL(d)

# Bind the minimizer to a name so the Python-side object stays referenced for
# as long as `fcn` is in use, rather than passing a temporary.
# NOTE(review): whether RooMinimizerFcn dereferences its context during plain
# evaluation depends on the ROOT version -- holding the reference is harmless.
minimizer = ROOT.RooMinimizer(nll)
fcn = ROOT.RooMinimizerFcn(nll, minimizer)
print(nll.getVal())
print('----')

params = fcn.GetFloatParamList()
n_params = params.getSize()

# Current values of the floating parameters, in the order `fcn` expects them.
# Stored under its own name so the pdf handle `p` is not overwritten.
x0 = np.array([params.at(i).getVal() for i in range(n_params)])

# Time 1000 NLL evaluations at random points; perf_counter is monotonic and
# high resolution, unlike the wall clock returned by time.time().
start = time.perf_counter()
for _ in range(1000):
    point = np.random.uniform(0.5, 1.5, size=n_params)
    v = fcn(point)
print(time.perf_counter() - start)
lukasheinrich commented 4 years ago

I spent some quality time with MINUIT, and this is how it propagates the error from its internal (unbounded) variables to the external (bounded) variables:

https://root.cern.ch/doc/master/TMinuit_8cxx_source.html#l07868

import iminuit
import jax

def toinf_single(x, bounds):
    """Map a value inside ``bounds = (lo, hi)`` to its arcsine coordinate.

    The value is first rescaled linearly so that ``lo -> -1`` and ``hi -> 1``,
    then passed through ``arcsin``.
    """
    lower, upper = bounds
    rescaled = 2*(x-lower)/(upper-lower) - 1
    return jax.numpy.arcsin(rescaled)
def toinf(x, limit_p1, limit_p2):
    """Apply ``toinf_single`` to the first two entries of ``x``.

    ``x[0]`` is transformed with ``limit_p1`` and ``x[1]`` with ``limit_p2``;
    the two results are returned as a jax array.
    """
    first = toinf_single(x[0], limit_p1)
    second = toinf_single(x[1], limit_p2)
    return jax.numpy.array([first, second])

def tobnd_single(x, bounds):
    """Map an unbounded coordinate ``x`` into ``bounds = (lo, hi)`` via sine.

    ``sin(x)`` lies in [-1, 1]; it is shifted and scaled so the result lies
    between ``lo`` and ``hi``.
    """
    lower, upper = bounds
    half_width = 0.5*(upper-lower)
    return lower + half_width*(jax.numpy.sin(x) + 1)
def tobnd(x, limit_p1, limit_p2):
    """Apply ``tobnd_single`` to the first two entries of ``x``.

    ``x[0]`` is mapped into ``limit_p1`` and ``x[1]`` into ``limit_p2``;
    the two results are returned as a jax array.
    """
    first = tobnd_single(x[0], limit_p1)
    second = tobnd_single(x[1], limit_p2)
    return jax.numpy.array([first, second])

def func(x):
    """Toy objective: a separable quadratic in ``x[0]`` and ``x[1]``.

    It is zero at ``(0.1, 0.2)`` and has curvature weights 7 and 0.70.
    """
    a = x[0]
    b = x[1]
    term_a = 7*(a-0.1)**2
    term_b = .70*(b-0.2)**2
    return term_a + term_b

# Parameter limits. The keys are chosen so the same dict can be splatted both
# into the Minuit constructor and into toinf/tobnd as keyword arguments.
bounds = dict(
    limit_p1 = (-1.234,7.654),
    limit_p2 = (-0.987,5.674),
)

# Bounded fit of `func` with two parameters named p1 and p2, both starting at
# 0.5 with an initial step of 0.01.
# NOTE(review): these keyword names look like the iminuit 1.x constructor
# API -- confirm against the installed iminuit version.
m = iminuit.Minuit(func, use_array_call=True,
               forced_parameters = ['p1','p2'],
               errordef = 1,
               p1 = 0.5,
               p2 = 0.5,
               error_p1 = 0.01,
               error_p2 = 0.01,
               **bounds
)

# Run the minimisation, then express the fitted values in the arcsine
# (unbounded) coordinates defined by toinf.
m.migrad()
minim = jax.numpy.array(m.np_values())
infmin = toinf(minim,**bounds)

# Hessian of func composed with tobnd, evaluated at infmin; halve it, invert
# it, then take the element-wise square root. to_error reads the diagonal
# entries as the step sizes in the unbounded coordinates.
inf_inv_hess = jax.numpy.sqrt(jax.numpy.linalg.inv(0.5*jax.hessian(lambda x: func(tobnd(x,**bounds)))(infmin)))

def to_error(idx):
    """Convert the unbounded-space step of parameter ``idx`` to a bounded error.

    Reads the module-level ``bounds``, ``inf_inv_hess``, ``infmin`` and
    ``minim``. The step ``dx`` is applied upwards and downwards in the
    unbounded coordinate, mapped back with ``tobnd_single``, and the two
    absolute displacements from the minimum are averaged. If ``dx`` exceeds 1,
    the upward displacement is replaced by the full width of the limits.
    """
    limits = bounds[f'limit_p{idx+1}']
    step = inf_inv_hess[idx, idx]
    centre = infmin[idx]
    best = minim[idx]

    down = tobnd_single(centre - step, limits) - best
    if step > 1:
        # Step is too large for the sine map to be meaningful: use the
        # whole allowed range as the upward displacement.
        up = limits[1] - limits[0]
    else:
        up = tobnd_single(centre + step, limits) - best

    return (abs(up) + abs(down))*0.5

# Fitted values in the external (bounded) coordinates.
print(minim)
# The same point in the unbounded arcsine coordinates.
print(infmin)
# Element-wise sqrt of the inverse half-Hessian of func taken directly in the
# external coordinates, i.e. without the limit transformation.
print(jax.numpy.sqrt(jax.numpy.linalg.inv(0.5*jax.hessian(func)(minim))))
# The same quantity computed in the unbounded coordinates.
print(inf_inv_hess)
# Errors propagated by hand through the limit transformation, next to the
# errors reported by Minuit, for direct comparison.
print('error JAX   ',jax.numpy.array([to_error(0),to_error(1)]))
print('error MINUIT',m.np_errors())

This means that we can re-implement the error calculation in any backend, provided we have the Hessian matrix. In JAX this is the easiest, since the Hessian is available through automatic differentiation.

with this we can do at least impact plots, but not yet full covariance