google / tangent

Source-to-Source Debuggable Derivatives in Pure Python
Apache License 2.0
2.32k stars, 434 forks

Forward function cannot call nested functions #77

Closed: wangkuiyi closed this issue 6 years ago

wangkuiyi commented 6 years ago

The following example

import tangent

def f(x):
    def a(x):
        return x * x
    return a(x)

df = tangent.grad(f, verbose=0)

fails with this exception:

  File "/usr/local/lib/python2.7/dist-packages/tangent/annotate.py", line 64, in resolve
    node.id, self.func.__name__))
AttributeError: Failed to resolve name "a" used by "f".

This is because ResolveCalls.visit_FunctionDef

https://github.com/google/tangent/blob/3318ecab3eec71ef97702df66a59d1a2197c8de3/tangent/annotate.py#L44-L46

does not add the definition of a to self.namespace, so the following snippet

https://github.com/google/tangent/blob/3318ecab3eec71ef97702df66a59d1a2197c8de3/tangent/annotate.py#L54-L56

cannot resolve the call from f to a.
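To make the failure mode concrete, here is a minimal Python 3 sketch of a resolver that, like the code linked above, looks every called name up in a flat namespace made of the function's module globals plus builtins. It is a simplification written for illustration, not tangent's actual ResolveCalls; `load` and `resolve_called_names` are hypothetical helpers. Running it on the original f reproduces the same error message, while a version with a at module level resolves cleanly.

```python
import ast
import builtins

ORIGINAL = '''
def f(x):
    def a(x):
        return x * x
    return a(x)
'''

REFACTORED = '''
def a(x):
    return x * x

def f(x):
    return a(x)
'''


def load(source, name):
    """Execute `source` as a fresh module namespace and return function `name`."""
    module_ns = {}
    exec(source, module_ns)
    return module_ns[name]


def resolve_called_names(func, source):
    """Hypothetical, simplified resolver: every name called in `source` must be
    found in func's module globals or in builtins. Local definitions such as
    a nested `def a` are never added to this namespace."""
    namespace = func.__globals__
    resolved = {}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            name = node.func.id
            if name in namespace:
                resolved[name] = namespace[name]
            elif hasattr(builtins, name):
                resolved[name] = getattr(builtins, name)
            else:
                raise AttributeError('Failed to resolve name "%s" used by "%s".'
                                     % (name, func.__name__))
    return resolved


try:
    resolve_called_names(load(ORIGINAL, "f"), ORIGINAL)
except AttributeError as e:
    print(e)  # Failed to resolve name "a" used by "f".

# With `a` at module level, it is found in f's globals.
print(sorted(resolve_called_names(load(REFACTORED, "f"), REFACTORED)))  # ['a']
```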

I don't think supporting nested functions is necessary, but it seems reasonable to restrict the forms of input that tangent accepts, for example with a style guide similar to the Google C++ Style Guide.
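A restriction like that could be enforced mechanically by a lint-style pass that rejects nested function definitions before differentiation starts. The sketch below is hypothetical: `find_nested_functions` is not part of tangent, and it relies only on the standard ast module.

```python
import ast


def find_nested_functions(source):
    """Return (outer, inner, lineno) for every def nested inside another def."""
    found = []

    def visit(node, enclosing):
        for child in ast.iter_child_nodes(node):
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if enclosing is not None:
                    found.append((enclosing, child.name, child.lineno))
                # Anything defined inside `child` is nested inside it.
                visit(child, child.name)
            else:
                visit(child, enclosing)

    visit(ast.parse(source), None)
    return found


SOURCE = '''
def f(x):
    def a(x):
        return x * x
    return a(x)
'''

for outer, inner, lineno in find_nested_functions(SOURCE):
    # SOURCE starts with a newline, so `def a` is on line 3.
    print('line %d: "%s" is defined inside "%s"; move it to module level'
          % (lineno, inner, outer))
```

Methods of a class defined at module level are not flagged, because their nearest enclosing scope is not a function.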

wangkuiyi commented 6 years ago

For more details, here is the full traceback of the error above:

root@be77776aef91:/tangent/tests# python a.py
Traceback (most recent call last):
  File "a.py", line 8, in <module>
    df = tangent.grad(f, verbose=0)
  File "/usr/local/lib/python2.7/dist-packages/tangent/grad_util.py", line 386, in grad
    verbose=verbose)
  File "/usr/local/lib/python2.7/dist-packages/tangent/grad_util.py", line 290, in autodiff
    check_dims, verbose)
  File "/usr/local/lib/python2.7/dist-packages/tangent/grad_util.py", line 144, in autodiff_tree
    check_dims, verbose)
  File "/usr/local/lib/python2.7/dist-packages/tangent/grad_util.py", line 89, in autodiff_ast
    node = annotate.resolve_calls(func)
  File "/usr/local/lib/python2.7/dist-packages/tangent/annotate.py", line 110, in resolve_calls
    ResolveCalls(func).visit(node)
  File "/usr/lib/python2.7/ast.py", line 241, in visit
    return visitor(node)
  File "/usr/lib/python2.7/ast.py", line 249, in generic_visit
    self.visit(item)
  File "/usr/lib/python2.7/ast.py", line 241, in visit
    return visitor(node)
  File "/usr/local/lib/python2.7/dist-packages/tangent/annotate.py", line 45, in visit_FunctionDef
    self.generic_visit(node)
  File "/usr/lib/python2.7/ast.py", line 249, in generic_visit
    self.visit(item)
  File "/usr/lib/python2.7/ast.py", line 241, in visit
    return visitor(node)
  File "/usr/lib/python2.7/ast.py", line 251, in generic_visit
    self.visit(value)
  File "/usr/lib/python2.7/ast.py", line 241, in visit
    return visitor(node)
  File "/usr/local/lib/python2.7/dist-packages/tangent/annotate.py", line 66, in visit_Call
    func = resolve(node.func)
  File "/usr/local/lib/python2.7/dist-packages/tangent/annotate.py", line 64, in resolve
    node.id, self.func.__name__))
AttributeError: Failed to resolve name "a" used by "f".

I am also attaching the AST of the function definition of f for your reference:

        FunctionDef(name='f',
                    args=arguments(args=[Name(id='x', ctx=Param())],
                                   vararg=None,
                                   kwarg=None,
                                   defaults=[]),
                    body=[
                        FunctionDef(name='a',
                                    args=arguments(args=[Name(id='x', ctx=Param())],
                                                   vararg=None,
                                                   kwarg=None,
                                                   defaults=[]),
                                    body=[Return(value=BinOp(left=Name(id='x', ctx=Load()),
                                                             op=Mult(),
                                                             right=Name(id='x', ctx=Load())))],
                                    decorator_list=[]),
                        Return(
                            value=Call(func=Name(id='a', ctx=Load()),
                                       args=[Name(id='x', ctx=Load())],
                                       keywords=[],
                                       starargs=None,
                                       kwargs=None))],
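The dump above comes from Python 2.7 (hence ctx=Param() and starargs). On Python 3 the same structure can be inspected with the standard ast module; some node fields differ (parameters are arg nodes, for instance), but the relevant point is unchanged: a exists only as a FunctionDef statement inside f's body, while the only name bound at module level is f. A short sketch:

```python
import ast

SOURCE = '''
def f(x):
    def a(x):
        return x * x
    return a(x)
'''

module = ast.parse(SOURCE)
f_def = module.body[0]
inner_def, ret = f_def.body

# `a` is a statement inside f's body ...
print(type(inner_def).__name__, inner_def.name)     # FunctionDef a
# ... and f's return value calls it by bare name.
print(type(ret.value).__name__, ret.value.func.id)  # Call a

# Only `f` is defined at the top level of the module.
print([node.name for node in module.body])          # ['f']

# Full Python 3 dump of f, comparable to the Python 2.7 dump above.
print(ast.dump(f_def))
```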
mdanatg commented 6 years ago

This is indeed a known limitation: closures are not yet supported. Addressing it is more involved, because not only would the side effects of the closure need to be detected, but the inner function would also need to be differentiated.
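The "side effects of the closure" are easy to see in plain Python 3: an inner function can read, and with nonlocal even rebind, variables of the enclosing function. The hypothetical f below (not from the issue) uses only standard code-object attributes to show which enclosing variables are shared with the nested function; a source-to-source differentiator would have to track that shared state across each inner call, in addition to differentiating a itself.

```python
import types


def f(x):
    calls = 0

    def a(x):
        nonlocal calls
        calls += 1   # side effect: rebinds a variable of the enclosing f
        return x * x

    y = a(x) + a(x)
    return y, calls


# Variables of f that a nested function captures are stored in "cells".
print(f.__code__.co_cellvars)             # ('calls',)

# The code object of the nested `a` is stored among f's constants.
inner = next(c for c in f.__code__.co_consts if isinstance(c, types.CodeType))
print(inner.co_name, inner.co_freevars)   # a ('calls',)

# Each call to `a` changes state that belongs to f.
print(f(3.0))                             # (18.0, 2)
```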

Can you rewrite the code and pull a out of f? I realize that might not always be a practical refactoring.
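When the inner function uses only its own arguments, as a does in the original example, pulling it out is mechanical. The refactoring gets less practical when the inner function captures locals of the enclosing function: those captured values have to become explicit parameters. A hypothetical sketch (the scale parameter and the f_nested name are illustrative, not from the issue):

```python
# Before: `a` is nested in f and closes over `scale`, a parameter of f.
# Per this issue, tangent cannot resolve a call to a nested function like this.
def f_nested(x, scale):
    def a(x):
        return scale * x * x
    return a(x)


# After: `a` lives at module level and the captured value is passed explicitly,
# so a lookup in f's module globals can find it.
def a(x, scale):
    return scale * x * x


def f(x, scale):
    return a(x, scale)


# Sanity check: the refactoring preserves the function's values ...
for x in (-2.0, 0.5, 3.0):
    assert f(x, 1.5) == f_nested(x, 1.5)

# ... and a central finite difference approximates d/dx (scale * x**2) = 2 * scale * x.
h = 1e-6
x0, s = 3.0, 1.5
approx = (f(x0 + h, s) - f(x0 - h, s)) / (2 * h)
print(round(approx, 4))  # 9.0
```

This sketch does not run tangent itself; the expectation, following the suggestion above, is that tangent.grad(f) can then resolve a through f's module globals.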

wangkuiyi commented 6 years ago

Oh, I see the limitation is already listed here. Sure, I can rewrite the application code to pull a out of f. Thank you @mdanatg!