
torch.fx fails when tracing functions from other libs #73469

Open yinpeiqi opened 2 years ago

yinpeiqi commented 2 years ago

šŸ› Describe the bug

When trying to use torch.fx to trace functions from other libraries, I ran into the failure below:

import torch
import torch.nn as nn
from torch.fx import Tracer, GraphModule
import dgl
from dgl.nn.functional import edge_softmax

class Test(nn.Module):
    def forward(self, graph, x):
        return edge_softmax(graph, x)

m = Test()
# autowrap_functions should record edge_softmax as a single call_function node
traced = GraphModule(m, Tracer(autowrap_functions=(edge_softmax,)).trace(m))
print(traced.graph)
print(traced.code)

The output is:

graph():
    %graph : [#users=1] = placeholder[target=graph]
    %x : [#users=1] = placeholder[target=x]
    %edge_softmax : [#users=1] = call_function[target=dgl.ops.edge_softmax.edge_softmax](args = (%graph, %x), kwargs = {})
    return edge_softmax

torch.fx._symbolic_trace.wrap("dgl_ops_edge_softmax_edge_softmax")

def forward(self, graph, x):
    edge_softmax = dgl_ops_edge_softmax_edge_softmax(graph, x);  graph = x = None
    return edge_softmax

The '.' characters were replaced by '_' in the name of the custom function, so the generated code refers to a mangled name rather than the dotted qualified name. I checked the relevant functions, and I think the cause is in torch/fx/graph.py, line 314:

        def add_global(name_hint: str, obj: Any):
            """Add an obj to be tracked as a global.

            We call this for names that reference objects external to the
            Graph, like functions or types.

            Returns: the global name that should be used to reference 'obj' in generated source.
            """
            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                # HACK: workaround for how torch custom ops are registered. We
                # can't import them like normal modules so they must retain their
                # fully qualified name.
                return _get_qualified_name(obj)

This issue is caused by the add_global function, which only preserves the fully qualified (dotted) name for objects from torch itself. Could it be changed to support importing functions from other libraries as well?
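In the meantime, one possible workaround (just a sketch, relying on the documented torch.fx.wrap mechanism; edge_softmax_wrapper is a hypothetical helper name, not part of dgl) is to hide the third-party call behind a module-level wrapper in your own code, so the generated code references a plain, dot-free name:

import torch
import torch.fx
import torch.nn as nn
from dgl.nn.functional import edge_softmax

# Hypothetical wrapper: it lives at module level in our own code, so the
# generated source can refer to it by its plain, unmangled name.
@torch.fx.wrap
def edge_softmax_wrapper(graph, x):
    return edge_softmax(graph, x)

class Test(nn.Module):
    def forward(self, graph, x):
        return edge_softmax_wrapper(graph, x)

traced = torch.fx.symbolic_trace(Test())
print(traced.code)  # expected to call edge_softmax_wrapper directly

Whether this fully avoids the re-wrapping issue in the emitted wrap(...) call is untested here, but it keeps the dotted dgl.ops path out of add_global entirely.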

Versions

github master branch

cc @ezyang @SherlockNoMad

pommedeterresautee commented 2 years ago

FWIW, I get an error from replace_pattern when using numpy in a custom operator:

import torch
from torch.fx import replace_pattern

# load whatever model with a `matmul` into `fx_graph` (a GraphModule)

class MyFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, b):
        # force the computation through an external lib, numpy here
        c = a.cpu().detach().numpy() + b.cpu().detach().numpy()
        return torch.from_numpy(c)

def pattern(x, y):
    return torch.matmul(x, y)

def replacement(x, y):
    return MyFunc.apply(x, y)

replace_pattern(fx_graph, pattern, replacement)  # <- crash

produces:

# ...
 File "<ipython-input-20-71acf7a39839>", line 23, in replacement
    return MyFunc.apply(x, y)
  File "<ipython-input-20-71acf7a39839>", line 10, in forward
    return torch.from_numpy(c)
TypeError: expected np.ndarray (got Proxy)

I understand that replace_pattern is not really usable in this case, as it will probably always fail to generate the replacement graph (I thought MyFunc would be opaque to FX and it would just insert a MyFunc call). Not sure how this can be worked around. Any idea @ezyang?
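One workaround that might apply here (a sketch, assuming replace_pattern's symbolic tracing of the replacement respects torch.fx.wrap, and reusing MyFunc, pattern, and fx_graph from above; my_func_apply is a hypothetical helper) is to hide MyFunc.apply behind a module-level wrapper marked as a leaf, so tracing records a single call instead of descending into forward, where torch.from_numpy receives a Proxy:

import torch
import torch.fx
from torch.fx import replace_pattern

# Hypothetical leaf wrapper: tracing should record one call_function node for
# it rather than tracing through MyFunc.forward.
@torch.fx.wrap
def my_func_apply(x, y):
    return MyFunc.apply(x, y)

def replacement(x, y):
    return my_func_apply(x, y)

replace_pattern(fx_graph, pattern, replacement)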