HIPS / autograd

Efficiently computes derivatives of NumPy code.
MIT License
7k stars 912 forks source link

KeyError for nanmean #395

Open jluttine opened 6 years ago

jluttine commented 6 years ago

Differentiating `nanmean` with `value_and_grad`, as in

>>> import autograd
>>> autograd.value_and_grad(autograd.numpy.nanmean)(autograd.numpy.array([1,2,3]))

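raises a `KeyError` in the `primitive_vjps[fun]` lookup in `autograd/core.py`. This suggests that no VJP (gradient) is registered for `nanmean`: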

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-12-4dbab693af80> in <module>()
----> 1 autograd.value_and_grad(autograd.numpy.nanmean)(autograd.numpy.array([1,2,3]))

/.../lib/python3.5/site-packages/autograd/wrap_util.py in nary_f(*args, **kwargs)
     18             else:
     19                 x = tuple(args[i] for i in argnum)
---> 20             return unary_operator(unary_f, x, *nary_op_args, **nary_op_kwargs)
     21         return nary_f
     22     return nary_operator

/.../lib/python3.5/site-packages/autograd/differential_operators.py in value_and_grad(fun, x)
    127     """Returns a function that returns both value and gradient. Suitable for use
    128     in scipy.optimize"""
--> 129     vjp, ans = _make_vjp(fun, x)
    130     return ans, vjp(vspace(ans).ones())
    131 

/.../lib/python3.5/site-packages/autograd/core.py in make_vjp(fun, x)
      8 def make_vjp(fun, x):
      9     start_node = VJPNode.new_root(x)
---> 10     end_value, end_node =  trace(start_node, fun, x)
     11     if end_node is None:
     12         def vjp(g): return vspace(x).zeros()

/.../lib/python3.5/site-packages/autograd/tracer.py in trace(start_node, fun, x)
      8     with trace_stack.new_trace() as t:
      9         start_box = new_box(x, t, start_node)
---> 10         end_box = fun(start_box)
     11         if isbox(end_box) and end_box._trace == start_box._trace:
     12             return end_box._value, end_box._node

/.../lib/python3.5/site-packages/autograd/wrap_util.py in unary_f(x)
     13                 else:
     14                     subargs = subvals(args, zip(argnum, x))
---> 15                 return fun(*subargs, **kwargs)
     16             if isinstance(argnum, int):
     17                 x = args[argnum]

/.../lib/python3.5/site-packages/autograd/tracer.py in f_wrapped(*args, **kwargs)
     43             argnums = tuple(argnum    for argnum, _   in boxed_args)
     44             ans = f_wrapped(*argvals, **kwargs)
---> 45             node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents)
     46             return new_box(ans, trace, node)
     47         else:

/.../lib/python3.5/site-packages/autograd/core.py in __init__(self, value, fun, args, kwargs, parent_argnums, parents)
     28     def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
     29         self.parents = parents
---> 30         self.vjp = primitive_vjps[fun](parent_argnums, value, args, kwargs)
     31 
     32     def initialize_root(self, value):

KeyError: <function primitive.<locals>.f_wrapped at 0x7fd2cc339510>

jluttine commented 6 years ago

Perhaps related to issue #384: https://github.com/HIPS/autograd/issues/384