google / tangent

Source-to-Source Debuggable Derivatives in Pure Python
Apache License 2.0

Function's derivative only returns the gradient of the first argument #71

Closed: tonyyang-svail closed this issue 6 years ago

tonyyang-svail commented 6 years ago

Hi, I was trying out Tangent in the playground, and it looks like Tangent only supports differentiating functions with a single input: the generated gradient covers only the first argument.

To reproduce:

def f(x, y):
  a = _mul(x, y)
  b = _mul(x, y)
  c = a + b
  return c

def _mul(m, n):
  out = m * n
  return out

import tangent
df = tangent.grad(f, verbose=1)

Generated code:

def dfdx(x, y, bc=1.0):
    # Initialize the tape
    _stack = tangent.Stack()
    _substack = tangent.Stack()
    tangent.push_stack(_stack, _substack, '_b10af127')
    a = pri__mulm(_substack, x, y)
    _substack = tangent.Stack()
    tangent.push_stack(_stack, _substack, '_76955021')
    b = pri__mulm(_substack, x, y)
    c = a + b
    assert tangent.shapes_match(c, bc
        ), 'Shape mismatch between return value (%s) and seed derivative (%s)' % (
        numpy.shape(c), numpy.shape(bc))

    # Grad of: c = a + b
    _ba = tangent.unbroadcast(bc, a)
    _bb = tangent.unbroadcast(bc, b)
    ba = _ba
    bb = _bb

    # Grad of: b = _mul(x, y)
    _substack = tangent.pop_stack(_stack, '_76955021')
    dxs = _d_muldm(_substack, bb, x, y)
    _bx2 = dxs[0]
    bx = _bx2

    # Grad of: a = _mul(x, y)
    _substack = tangent.pop_stack(_stack, '_b10af127')
    dxs = _d_muldm(_substack, ba, x, y)
    _bx = dxs[0]
    bx = tangent.add_grad(bx, _bx)
    return bx

def pri__mulm(_stack, m, n):
    out = m * n
    result = out
    tangent.push(_stack, result, '_a6173701')
    return out

def _d_muldm(_stack, bout, m, n):
    result = tangent.pop(_stack, '_a6173701')

    # Grad of: out = m * n
    _bm = tangent.unbroadcast(bout * n, m)
    bm = _bm
    return bm, result

I would expect dfdx to return both bx and by (the gradients with respect to x and y), instead of just bx.
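For comparison, here is a hand-written sketch of what a gradient with respect to both arguments could look like, following the same "Grad of:" structure as the generated code. The names `dfdxy` and `_d_mul` are hypothetical, not Tangent output, and the sketch drops the tape (Stack) handling, since the adjoint of `out = m * n` only needs the inputs `m` and `n`. Analytically, f(x, y) = 2xy, so the expected gradients are (2y, 2x).

```python
def _d_mul(bout, m, n):
    # Adjoint of out = m * n: gradients w.r.t. BOTH m and n
    # (the generated _d_muldm above only returns the one for m).
    return bout * n, bout * m


def dfdxy(x, y, bc=1.0):
    """Hypothetical reverse-mode gradient of f(x, y) w.r.t. x and y."""
    # Grad of: c = a + b  (addition passes the seed through unchanged)
    ba = bc
    bb = bc

    # Grad of: b = _mul(x, y)
    bx, by = _d_mul(bb, x, y)

    # Grad of: a = _mul(x, y); x and y are used twice, so accumulate
    _bx, _by = _d_mul(ba, x, y)
    bx = bx + _bx
    by = by + _by
    return bx, by


print(dfdxy(3.0, 4.0))  # (8.0, 6.0)
```

Because x and y each feed into two `_mul` calls, their adjoints must be summed, which is what `tangent.add_grad(bx, _bx)` does for `bx` in the generated code.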

alexbw commented 6 years ago

Try tangent.grad(f, wrt=(0,1), verbose=1)
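The `wrt` argument selects which positional arguments to differentiate with respect to; as far as I can tell, Tangent's `grad` defaults to `wrt=(0,)`, which is why only `bx` came back. One way to sanity-check the two gradients returned with `wrt=(0, 1)` is to compare them against central finite differences. The helper `numerical_grad` below is a hypothetical pure-Python sketch, not part of Tangent; it mirrors the `wrt` idea by returning a single value for one index and a tuple for several.

```python
def _mul(m, n):
    out = m * n
    return out


def f(x, y):
    a = _mul(x, y)
    b = _mul(x, y)
    c = a + b
    return c


def numerical_grad(func, args, wrt=(0,), eps=1e-6):
    """Hypothetical helper: central finite-difference gradient of func
    w.r.t. the positional argument indices listed in wrt."""
    grads = []
    for i in wrt:
        plus = list(args)
        minus = list(args)
        plus[i] += eps
        minus[i] -= eps
        grads.append((func(*plus) - func(*minus)) / (2 * eps))
    # One index -> bare value (like the default dfdx); several -> tuple.
    return tuple(grads) if len(grads) > 1 else grads[0]


# With Tangent installed, the analytic counterpart would be:
#   df = tangent.grad(f, wrt=(0, 1))
#   df(3.0, 4.0)  # should match roughly 8.0 and 6.0
print(numerical_grad(f, (3.0, 4.0), wrt=(0, 1)))
```

Since f(x, y) = 2xy is linear in each argument separately, the central difference matches 2y and 2x up to floating-point rounding.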

tonyyang-svail commented 6 years ago

It worked. Thanks.