google / tangent

Source-to-Source Debuggable Derivatives in Pure Python
Apache License 2.0
2.32k stars 434 forks source link

Unable to call Hessian-vector product function if function calls other functions #93

Open wiso opened 5 years ago

wiso commented 5 years ago

Sorry if the example is not very minimal. I have a function defined as

$$-\sum_C \log \mathrm{Poisson}\!\left(\text{observed}_C \,\middle|\, \sum_P \text{eff}_{CP}\, n_P\, \mu_P\right)$$

where everything is constant except $\mu_P$. $\text{eff}_{CP}$ is an element of a matrix, while the other quantities are 1D vectors; the sum over $P$ is the matrix product `efficiencies @ (ntrue * mus)` (`@` is matrix multiplication).

import numpy as np
from scipy import stats
from scipy.special import gammaln, digamma

def logpoisson(lam, n):
    # Poisson log-probability log P(n | lam) = n*log(lam) - lam - log(n!),
    # with log(n!) expressed as gammaln(n + 1). Works elementwise on arrays.
    log_n_factorial = gammaln(n + 1.0)
    return n * np.log(lam) - lam - log_n_factorial

import tangent
from tangent.grads import adjoint
# Register a reverse-mode rule (tangent "adjoint") for scipy's gammaln so that
# tangent can differentiate code calling it. Since d/dx gammaln(x) = digamma(x),
# the gradient flowing into the result is scaled by digamma(x).
# The body is a tangent template: `d[...]` stands for "gradient of ..." and is
# rewritten by tangent when generating derivative code; it is not meant to be
# executed as ordinary Python.
@adjoint(gammaln)
def dgammaln(result, x):
  # result: the primal output gammaln(x); x: its input.
  d[x] = d[result] * digamma(x)

def hessian(f):
    # Build a function returning the dense Hessian of `f` with respect to its
    # first argument, assembled one vector-Hessian product (VHP) at a time.
    #
    # tangent.grad(tangent.grad(f)) returns a function whose last positional
    # parameter is the seed of the outer gradient (e.g. `bbpars` in
    # ddfunctiondparsdpars(pars, obs, ntrue, efficiencies, bminus_np_sum_p, bbpars)).
    # Seeding it with the i-th unit vector e_i gives e_i^T H, i.e. row i of H.
    import inspect

    vhp = tangent.grad(tangent.grad(f))

    # Find that seed parameter through the public `inspect` API instead of the
    # CPython-specific code-object fields (__code__.co_varnames/co_argcount).
    # Only parameters that can be passed by keyword are candidates, because the
    # seed is supplied as `**{seed_name: v}` below.
    keyword_params = [
        name
        for name, param in inspect.signature(vhp).parameters.items()
        if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    seed_name = keyword_params[-1]

    def hf(x, *args):
        # x: point at which to evaluate the Hessian; assumed to be a 1-D numpy
        # array (f's first argument). args: f's remaining arguments, passed
        # through unchanged. Returns the VHP rows stacked into one array.
        unit_vectors = np.eye(x.size)
        return np.array([vhp(x, *args, **{seed_name: e}) for e in unit_vectors])

    return hf

def function(pars, obs, ntrue, efficiencies):
    # Negative Poisson log-likelihood summed over all channels:
    #     -sum(obs * log(expected) - expected - gammaln(obs + 1))
    # with expected = efficiencies @ (ntrue * pars[:4]).
    # This function is differentiated by tangent (grad/hessian at module level),
    # which transforms this source directly; local names show up in the
    # generated code (e.g. the `bminus_np_sum_p` seed parameter), so keep them.
    mus = pars[:4]  # only the first 4 entries enter the model; extra entries are accepted
    # When pars has exactly 4 entries, mus holds the same values as pars.
    # Using pars directly (without the slice) was reported to work as well.
    expected = np.dot(efficiencies, ntrue * mus)  # per-channel Poisson mean; efficiencies needs 4 columns
    # p = logpoisson(expected, obs)  # equivalent, but calling a helper makes the tangent Hessian fail (IndexError: pop from empty list)
    p = obs * np.log(expected) - expected - gammaln(obs + 1.0)  # logpoisson inlined; works with tangent
    return -np.sum(p)

# Reverse-mode gradient of `function` w.r.t. its first argument `pars`, and a
# dense-Hessian builder based on tangent's second-order derivative.
grad, hess = tangent.grad(function), hessian(function)

# Per-process normalisations n_P (passed as the `ntrue` argument of `function`).
ntrues = np.array([8109.63147251,  636.80207692,  362.09635052,  105.68754852])
# Efficiency matrix eff_CP: one row per channel C (9 rows), one column per process P (4 columns).
efficiencies = np.array([[8.48528557e-02, 1.16218361e-02, 1.60701261e-02, 2.14594047e-04],
       [1.49223448e-01, 2.25235106e-02, 3.32789360e-02, 5.09044476e-04],
       [7.06510161e-02, 5.48355317e-02, 4.54045768e-02, 1.75636472e-03],
       [3.41855208e-02, 6.04847808e-02, 3.98815566e-02, 1.87227331e-03],
       [6.47040512e-03, 1.91409691e-02, 1.20408359e-02, 8.01323682e-04],
       [1.64533523e-03, 5.70742974e-03, 3.62071954e-03, 3.07357073e-04],
       [1.82974652e-02, 2.61130198e-02, 3.61665791e-02, 2.17032553e-02],
       [1.48041680e-02, 2.95195735e-02, 3.27501289e-02, 2.23425697e-02],
       [6.04401223e-03, 1.28281121e-02, 1.48972295e-02, 9.19357318e-03]])

# Observed counts set exactly to the model expectation at mus = 1, so
# pars = [1, 1, 1, 1] is the minimum of `function`.
obs = np.dot(efficiencies, ntrues)

# Gradient at the minimum: all components are ~0.
# Use a float array: tangent allocates the gradient buffer from the input
# (see `bpars[_4] = _bpars` in the traceback), presumably with its dtype, so an
# integer `pars` could silently truncate non-zero gradient entries.
grad(np.array([1.0, 1.0, 1.0, 1.0]), obs, ntrues, efficiencies)
# Extra trailing parameters are allowed: `function` only uses pars[:4], so the
# gradient with respect to them is 0.
grad(np.array([1.2, 1, 1, 1, 100]), obs, ntrues, efficiencies)

# 5x5 Hessian; the row and column of the unused 5th parameter are 0.
hess(np.array([1, 1, 1, 1, 100.]), obs, ntrues, efficiencies)

If, inside `function`, I call another function:

p = logpoisson(expected, obs)  # this doesn't work

instead of inlining the expression, computing the Hessian fails with:

IndexError                                Traceback (most recent call last)
<ipython-input-19-83985776cc70> in <module>
----> 1 hess(np.array([1, 1, 1, 1, 100.]), obs, ntrues, efficiencies)

<ipython-input-2-9aa78801ca8b> in hf(x, *args)
      6         for i in range(x.size):
      7             v = np.eye(1, x.size, i)[0]
----> 8             H.append(vhp(x, *args, **{last_arg:v}))
      9         return np.array(H)
     10     return hf

/tmp/tmph8lq_o2h/tangent_7e31.py in ddfunctiondparsdpars(pars, obs, ntrue, efficiencies, bminus_np_sum_p, bbpars)
     54 
     55     # Beginning of backward pass
---> 56     _4 = tangent.pop(_stack, '_c541ac94')
     57 
     58     # Grad of: bpars[_4] = _bpars

~/venv3/lib/python3.7/site-packages/tangent/utils.py in pop(stack, op_id)
    669   """
    670   if __debug__:
--> 671     pushed_value, pushed_op_id = stack.pop()
    672     assert pushed_op_id == op_id, 'Wanted %s, got %s' % (op_id, pushed_op_id)
    673   else:

~/venv3/lib/python3.7/site-packages/tangent/utils.py in pop(self)
     59 
     60   def pop(self):
---> 61     return self._stack.pop()
     62 
     63   def __len__(self):

IndexError: pop from empty list