Open · lukasheinrich opened 4 years ago
This is how you call the ROOT NLL with a NumPy array:
import ROOT
import sys
import array
import numpy as np
import time
# Open the file given on the command line and build the NLL of the combined
# pdf on the observed data. ``f`` stays bound for the whole script so the
# workspace read from it remains available.
f = ROOT.TFile.Open(sys.argv[1])
if not f or f.IsZombie():
    sys.exit('could not open {}'.format(sys.argv[1]))
w = f.Get('combined')
pdf = w.pdf('simPdf')
d = w.data('obsData')
nll = pdf.createNLL(d)

# Keep the RooMinimizer bound to a name: RooMinimizerFcn is constructed with a
# pointer to it, so it must not be an unnamed temporary that Python is free to
# destroy while ``fcn`` is still being called.
minimizer = ROOT.RooMinimizer(nll)
fcn = ROOT.RooMinimizerFcn(nll, minimizer)
print(nll.getVal())
print('----')

# fcn is called with a flat array holding one value per floating parameter,
# in the order of GetFloatParamList().
params = fcn.GetFloatParamList()
n_params = params.getSize()
# Current parameter values (shows how to extract them; the benchmark below
# evaluates at random points instead).
x0 = np.array([params.at(i).getVal() for i in range(n_params)])

# Benchmark: evaluate the NLL at 1000 random points and print the total
# wall time in seconds. perf_counter is a monotonic, high-resolution clock
# intended for measuring elapsed time.
n_calls = 1000
start = time.perf_counter()
for _ in range(n_calls):
    x = np.random.uniform(0.5, 1.5, size=n_params)
    v = fcn(x)
print(time.perf_counter() - start)
I spent some quality time with MINUIT, and this is how it propagates the errors from its internal (unbounded) variables to the external (bounded) ones:
https://root.cern.ch/doc/master/TMinuit_8cxx_source.html#l07868
import iminuit
import jax
def toinf_single(x,bounds):
    """Map a bounded (external) value to MINUIT's unbounded internal coordinate.

    ``bounds`` is a ``(lo, hi)`` pair. The value is first rescaled from
    [lo, hi] to [-1, 1] and then sent through arcsin (the sine
    parameter transformation); ``tobnd_single`` maps it back.
    """
    lower, upper = bounds
    scaled = 2 * (x - lower) / (upper - lower) - 1
    return jax.numpy.arcsin(scaled)
def toinf(x,limit_p1, limit_p2):
    """Map the two-parameter point ``x`` from external to internal coordinates.

    ``limit_p1`` and ``limit_p2`` are the ``(lo, hi)`` bounds of ``x[0]``
    and ``x[1]``; each component is transformed with ``toinf_single``.
    """
    internal = [
        toinf_single(value, limits)
        for value, limits in zip((x[0], x[1]), (limit_p1, limit_p2))
    ]
    return jax.numpy.array(internal)
def tobnd_single(x,bounds):
    """Map an unbounded internal value back into its bounded range.

    ``bounds`` is a ``(lo, hi)`` pair. Counterpart of ``toinf_single``:
    sin(x) lies in [-1, 1], which is shifted and scaled onto [lo, hi].
    """
    lower, upper = bounds
    half_width = 0.5 * (upper - lower)
    return lower + half_width * (jax.numpy.sin(x) + 1)
def tobnd(x,limit_p1, limit_p2):
    """Map the two-parameter point ``x`` from internal to external coordinates.

    ``limit_p1`` and ``limit_p2`` are the ``(lo, hi)`` bounds of the first
    and second parameter; each component is transformed with ``tobnd_single``.
    """
    external = [
        tobnd_single(value, limits)
        for value, limits in zip((x[0], x[1]), (limit_p1, limit_p2))
    ]
    return jax.numpy.array(external)
def func(x):
    """Toy objective: a separable quadratic with its minimum at (0.1, 0.2)."""
    first, second = x[0], x[1]
    term_a = 7 * (first - 0.1) ** 2
    term_b = 0.70 * (second - 0.2) ** 2
    return term_a + term_b
# Bounds of p1 and p2. They are passed to Minuit as limit_p1/limit_p2 and
# reused below for the manual internal<->external transformations.
bounds = dict(
    limit_p1=(-1.234, 7.654),
    limit_p2=(-0.987, 5.674),
)

# Error definition: 1 means func is treated as chi2-like (a 1-sigma interval
# is where func rises by 1 above its minimum).
ERRORDEF = 1

# NOTE(review): this is the pre-2.0 iminuit keyword API (use_array_call,
# forced_parameters, error_*/limit_* kwargs, np_values); confirm the
# installed iminuit version.
m = iminuit.Minuit(
    func,
    use_array_call=True,
    forced_parameters=['p1', 'p2'],
    errordef=ERRORDEF,
    p1=0.5,
    p2=0.5,
    error_p1=0.01,
    error_p2=0.01,
    **bounds
)
m.migrad()

# Fitted minimum in external (bounded) and internal (unbounded) coordinates.
minim = jax.numpy.array(m.np_values())
infmin = toinf(minim, **bounds)

# Covariance in internal coordinates. For error definition ERRORDEF the
# covariance is 2 * ERRORDEF * H^-1 = inv(H / (2 * ERRORDEF)); with
# ERRORDEF = 1 this is inv(0.5 * H).
inf_hess = jax.hessian(lambda x: func(tobnd(x, **bounds)))(infmin)
inf_cov = jax.numpy.linalg.inv(inf_hess / (2 * ERRORDEF))
# Only the diagonal (variances) has a meaningful square root; an element-wise
# sqrt of the whole matrix gives NaN for negative off-diagonal covariances.
# Keep a 2-D matrix so inf_inv_hess[i, i] is still the internal 1-sigma step
# of parameter i.
inf_inv_hess = jax.numpy.diag(jax.numpy.sqrt(jax.numpy.diag(inf_cov)))
def to_error(idx):
    """Propagate parameter ``idx``'s internal 1-sigma step to an external error.

    Mirrors TMinuit's internal->external error propagation: step by +/- dx
    (the internal 1-sigma step) around ``infmin``, map both ends back with
    ``tobnd_single``, and average the absolute displacements from ``minim``.
    When dx exceeds 1, the upward displacement is replaced by the full width
    of the parameter's allowed range.

    Reads the module-level ``bounds``, ``infmin``, ``minim`` and
    ``inf_inv_hess``.
    """
    lo_hi = bounds['limit_p{}'.format(idx + 1)]
    step = inf_inv_hess[idx, idx]
    centre_internal = infmin[idx]
    centre_external = minim[idx]

    shifted_dn = tobnd_single(centre_internal - step, lo_hi) - centre_external
    if step > 1:
        shifted_up = lo_hi[1] - lo_hi[0]
    else:
        shifted_up = tobnd_single(centre_internal + step, lo_hi) - centre_external
    return 0.5 * (abs(shifted_up) + abs(shifted_dn))
# Compare the results:
#  - the minimum in external and internal coordinates,
#  - naive errors from the Hessian of func in external coordinates
#    (ignoring the bounds),
#  - the internal-coordinate errors,
#  - bound-aware errors propagated with to_error next to MINUIT's own.
print(minim)
print(infmin)
# 0.5 * H because Minuit was configured with errordef=1 (cov = 2 * H^-1).
# Take the square root of the diagonal only: an element-wise sqrt of the full
# covariance gives NaN for negative off-diagonal entries.
naive_cov = jax.numpy.linalg.inv(0.5 * jax.hessian(func)(minim))
print(jax.numpy.diag(jax.numpy.sqrt(jax.numpy.diag(naive_cov))))
print(inf_inv_hess)
print('error JAX ', jax.numpy.array([to_error(i) for i in range(len(minim))]))
print('error MINUIT', m.np_errors())
This means we can re-implement the error calculation in any backend, as long as we have the Hessian matrix; in JAX this is easiest.
With this we can at least produce impact plots, but not yet the full covariance matrix.
This is a short script that numerically compares our NLL to ROOT's NLL. It will be useful for an apples-to-apples comparison of fit results.
Compared to: