google / tangent

Source-to-Source Debuggable Derivatives in Pure Python
Apache License 2.0

Assumption that numpy is imported as np is not always true #23

Closed dmitriy-serdyuk closed 6 years ago

dmitriy-serdyuk commented 6 years ago

For example, this fails:

In [1]: import numpy

In [2]: import tangent

In [3]: def f(W, x):
   ...:   h1 = numpy.dot(x, W)
   ...:   h2 = numpy.tanh(h1)
   ...:   out = numpy.sum(h2)
   ...:   return out
   ...:
   ...: dfdW = tangent.grad(f)
   ...:

In [4]: dfdW(numpy.ones((10, 10)), numpy.ones(10))
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-4-70a6a8cf8bcb> in <module>()
----> 1 dfdW(numpy.ones((10, 10)), numpy.ones(10))

/var/folders/gr/btjlj89x0y17vf4ndzl5vklh0000gn/T/tmpa7f1cu6p/tangent_d792.py in dfdW(W, x, bout)
      4
      5     # Grad of: out = numpy.sum(h2)
----> 6     _bh2 = tangent.astype(tangent.unreduce(bout, numpy.shape(h2), None, np.
      7         _NoValue), h2)
      8     bh2 = _bh2

NameError: name 'np' is not defined

ghost commented 6 years ago

The error itself shows that line 6 refers to "np.", which the interpreter cannot resolve. It should be: _bh2 = tangent.astype(tangent.unreduce(bout, numpy.shape(h2), None, numpy._NoValue), h2)

bartvm commented 6 years ago

@rikudoayush The code is actually generated by Tangent, so @dmitriy-serdyuk wouldn't be able to change it easily. It is indeed a bug.

What is happening is this: if a function is called without keyword arguments, we take the default values from the function signature and use those to call the backward function. In this case we take keepdims=np._NoValue from numpy.sum and make sure that the gradient function (unreduce) is called with that value. This works if the keyword argument defaults are None, True, etc., but if the default is an actual object (like the singleton np._NoValue), it breaks when that object isn't available under that name in the namespace of the backward pass, which is what happens here.
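The failure mode described above can be reproduced without Tangent or NumPy. The following is only a minimal sketch under stated assumptions: `lib`, `reduce_sum`, `generated`, and `run_generated` are hypothetical stand-ins (for numpy, numpy.sum, the generated backward code, and its execution), not Tangent's actual internals.

```python
import inspect
import types

# Hypothetical stand-in for a library exposing a sentinel default,
# playing the role of numpy and numpy._NoValue.
lib = types.SimpleNamespace()


class _NoValueType:
    def __repr__(self):
        return "<no value>"


lib._NoValue = _NoValueType()


def reduce_sum(x, axis=None, keepdims=lib._NoValue):
    """Hypothetical forward function whose default is a real object."""
    return sum(x)


# Read the defaults from the function signature.
defaults = {
    name: param.default
    for name, param in inspect.signature(reduce_sum).parameters.items()
    if param.default is not inspect.Parameter.empty
}

# Generated code that spells the object default through a hard-coded alias "np".
generated = "result = (axis_default, np._NoValue)"


def run_generated(namespace):
    """Execute the generated source in a copy of the given namespace."""
    scope = dict(namespace, axis_default=defaults["axis"])
    exec(generated, scope)
    return scope["result"]


# The caller imported the library under another name, so "np" is missing.
try:
    run_generated({"numpy": lib})
    error = None
except NameError as exc:
    error = str(exc)

# The same code works once the assumed alias is present in the namespace.
ok = run_generated({"np": lib})
```

A default such as None survives because it can be written out as a literal. The sentinel can only be reached through a name, and that name has to exist wherever the generated code runs.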

Fixing this would require a bit of machinery, or maybe we should just require adjoints to explicitly specify the defaults of the function they are wrapping. Having a look at it now.