yinpeiqi opened this issue 2 years ago (status: open)
FWIW, I get an error from replace_pattern when using NumPy inside a custom operator:
import torch
from torch.fx import replace_pattern, symbolic_trace

# load whatever model with a `matmul`
fx_graph = symbolic_trace(model)  # `model`: any nn.Module whose forward calls torch.matmul

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # force to go through ext lib, np here
        c = a.cpu().detach().numpy() + b.cpu().detach().numpy()
        return torch.from_numpy(c)

def pattern(x, y):
    return torch.matmul(x, y)

def replacement(x, y):
    return MyFunc.apply(x, y)

replace_pattern(fx_graph, pattern, replacement)  # <- crash
produces:
# ...
  File "<ipython-input-20-71acf7a39839>", line 23, in replacement
    return MyFunc.apply(x, y)
  File "<ipython-input-20-71acf7a39839>", line 10, in forward
    return torch.from_numpy(c)
TypeError: expected np.ndarray (got Proxy)
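The traceback shows that tracing `replacement` actually runs `MyFunc.forward` on Proxy arguments. The NumPy round-trip then reaches `torch.from_numpy`, which needs a real array. The toy tracer below is a plain-Python sketch of that mechanism only. Every name in it (`Proxy`, `needs_concrete`, `leaf`, `trace`) is hypothetical, and none of it is torch.fx code. For contrast, it also treats a function as an opaque leaf, which is roughly the idea behind `torch.fx.wrap`.

```python
from typing import Any, Callable, List, Tuple

Node = Tuple[str, Any, tuple]  # (output name, target, args) -- toy graph node


class Proxy:
    """Toy traced value: operations on it are recorded, not computed."""

    def __init__(self, name: str, graph: List[Node]) -> None:
        self.name = name
        self.graph = graph

    def record(self, target: Any, args: tuple) -> "Proxy":
        out = Proxy(f"v{len(self.graph)}", self.graph)
        self.graph.append((out.name, target, args))
        return out

    def __add__(self, other: Any) -> "Proxy":
        return self.record("add", (self, other))


def needs_concrete(x: Any) -> float:
    # Plays the role of torch.from_numpy: it must inspect a real value.
    if not isinstance(x, (int, float)):
        raise TypeError(f"expected a number (got {type(x).__name__})")
    return float(x) * 2


def leaf(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Toy analogue of torch.fx.wrap: given a Proxy, emit one call node for fn.
    def wrapper(*args: Any) -> Any:
        for a in args:
            if isinstance(a, Proxy):
                return a.record(fn, args)
        return fn(*args)  # concrete inputs: just run the function
    return wrapper


opaque_leaf = leaf(needs_concrete)


def replacement(x: Any, y: Any) -> Any:
    return needs_concrete(x + y)  # executed eagerly on a Proxy -> TypeError


def replacement_leaf(x: Any, y: Any) -> Any:
    return opaque_leaf(x + y)  # recorded as a single opaque call node


def trace(fn: Callable[..., Any], nargs: int) -> List[Node]:
    graph: List[Node] = []
    fn(*[Proxy(f"arg{i}", graph) for i in range(nargs)])
    return graph


if __name__ == "__main__":
    try:
        trace(replacement, 2)
    except TypeError as e:
        print("plain:", e)  # expected a number (got Proxy)
    print([(n, getattr(t, "__name__", t)) for n, t, _ in trace(replacement_leaf, 2)])
```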
I understand that replace_pattern is probably not usable in this case. It symbolically traces replacement, so MyFunc.forward is executed on Proxy objects and the replacement graph can never be generated. (I had assumed MyFunc would be opaque to FX and that replace_pattern would simply insert a MyFunc call.)
I'm not sure how to work around this. Any ideas, @ezyang?
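One possible workaround is sketched below and assumes a recent PyTorch release. It hides `MyFunc.apply` behind a module-level function registered with `torch.fx.wrap`, so the tracer records one `call_function` node instead of running `forward` on Proxies. The wrapper name `my_func` and the toy module `M` are hypothetical. As a further assumption, `forward` is changed to a NumPy matmul so that the rewritten graph computes the same thing as the pattern.

```python
import numpy as np
import torch
import torch.fx
from torch.fx import replace_pattern, symbolic_trace


class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # Goes through NumPy, so it cannot be run on FX Proxy objects.
        c = np.matmul(a.cpu().detach().numpy(), b.cpu().detach().numpy())
        return torch.from_numpy(c)


def my_func(x, y):
    # Hypothetical module-level wrapper around the autograd.Function.
    return MyFunc.apply(x, y)


# Must be called at module top level: while tracing, calls to my_func are
# recorded as one call_function node instead of being executed.
torch.fx.wrap("my_func")


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)


def pattern(x, y):
    return torch.matmul(x, y)


def replacement(x, y):
    return my_func(x, y)


gm = symbolic_trace(M())
matches = replace_pattern(gm, pattern, replacement)  # traced without running forward

if __name__ == "__main__":
    print(gm.code)
    x, y = torch.randn(2, 3), torch.randn(3, 4)
    print(torch.allclose(gm(x, y), x @ y))
```

`torch.fx.wrap` patches the globals of the module where it is called, so it has to live in the same module that defines `replacement`. At runtime the recorded node calls the original `my_func`, which then runs `MyFunc.forward` on real tensors.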
🐛 Describe the bug
When trying to use torch.fx to trace functions from other libraries, I ran into the failure below:
The output is:
The '.' in the names of the custom functions was replaced by '_'. I checked the related functions, and I think the problem is in torch/fx/graph.py, line 314:
The issue is caused by this 'add_global' function. Could it be changed to support importing functions from other libraries?
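To see the renaming described above, you can trace a call to a non-torch function that is kept as a single node and then print the generated code. This is a minimal sketch that assumes a recent PyTorch. `my_helper` is a hypothetical stand-in for a function from another library, wrapped with `torch.fx.wrap` so that tracing records the call instead of inlining it. The exact global name that shows up in `gm.code` depends on the name of the module that defines `my_helper`.

```python
import torch
import torch.fx


def my_helper(x):
    # Hypothetical stand-in for a function that lives in another library.
    return x + 1


torch.fx.wrap("my_helper")  # keep the call as one node instead of tracing into it


def f(x):
    return my_helper(x) * 2


gm = torch.fx.symbolic_trace(f)

if __name__ == "__main__":
    # Look at how the call to my_helper is spelled in the generated forward().
    print(gm.code)
    print(gm(torch.zeros(3)))
```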
Versions
GitHub master branch
cc @ezyang @SherlockNoMad