pytorch / extension-cpp

C++ extensions in PyTorch

How does a layer built as a C++ extension translate to TorchScript or ONNX? #72

Open yanglinxiabuaaa opened 3 years ago

yanglinxiabuaaa commented 3 years ago

Test code:

```python
import torch

from LLTM import LLTM  # assumed: the LLTM module from LLTM.py (path shown in the traceback)

batch_size = 16
input_features = 32
state_size = 128

X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

rnn = LLTM(input_features, state_size)

inputs = (X, (h, C))

traced = torch.jit.trace(rnn, inputs)
print(traced.graph)
torch.jit.save(traced, "lltm.pt")
```

```
graph(%self : __torch__.torch.nn.modules.module.Module,
      %input : Float(16, 32),
      %5 : (Float(16, 128), Float(16, 128))):
  %39 : Tensor = prim::GetAttr[name="bias"](%self)
  %38 : Tensor = prim::GetAttr[name="weights"](%self)
  %old_h : Float(16, 128), %old_cell : Float(16, 128) = prim::TupleUnpack(%5)
  %34 : (Tensor, Tensor) = ^LLTMFunction()(%input, %38, %39, %old_h, %old_cell) # /workspace/yanglinxia/CenterNet/torchscript/lltm-extension/LLTM.py:42:0
  %35 : Float(16, 128), %36 : Float(16, 128) = prim::TupleUnpack(%34)
  %37 : (Float(16, 128), Float(16, 128)) = prim::TupleConstruct(%35, %36)
  return (%37)
```

```
Traceback (most recent call last):
  File "script.py", line 22, in <module>
    torch.jit.save(traced, "lltm.pt")
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 153, in save
    m.save(f, _extra_files=_extra_files)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1626, in save
    return self._c.save(*args, **kwargs)
RuntimeError: Could not export Python function call 'LLTMFunction'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/workspace/yanglinxia/CenterNet/torchscript/lltm-extension/LLTM.py(42): forward
/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py(516): _slow_forward
/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py(530): __call__
/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py(1034): trace_module
/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py(882): trace
script.py(19): <module>
```
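The error says `torch.jit.save` cannot serialize the `^LLTMFunction()` node in the traced graph: `LLTMFunction` is invoked as a Python function, so the tracer records it as an opaque Python call instead of TorchScript operators. The sketch below is one way to check for such nodes and to see what a saveable graph looks like. It rests on assumptions: `PurePythonLLTM`, `python_op_nodes` and `trace_and_roundtrip` are hypothetical names, the module reimplements the LLTM forward pass with plain torch ops (modeled on the pure-Python baseline from the C++ extension tutorial, not on the C++ kernel), and `^Name()` nodes are assumed to have kind `prim::PythonOp`. This sidesteps the C++ code rather than exporting it; keeping the C++ kernel in a saved TorchScript model would instead require registering it as a custom operator (for example with `TORCH_LIBRARY`) and calling it through `torch.ops`, which is not shown here.

```python
import io
import math

import torch
import torch.nn.functional as F


class PurePythonLLTM(torch.nn.Module):
    """Hypothetical LLTM cell written only with torch ops (no autograd.Function),
    so torch.jit.trace records plain aten ops that TorchScript can serialize."""

    def __init__(self, input_features, state_size):
        super().__init__()
        self.input_features = input_features
        self.state_size = state_size
        # One fused weight matrix for the three gates, applied to cat([h, x]).
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, input, state):
        old_h, old_cell = state
        X = torch.cat([old_h, input], dim=1)
        gate_weights = F.linear(X, self.weights, self.bias)
        gates = gate_weights.chunk(3, dim=1)
        input_gate = torch.sigmoid(gates[0])
        output_gate = torch.sigmoid(gates[1])
        candidate_cell = F.elu(gates[2])
        new_cell = old_cell + candidate_cell * input_gate
        new_h = torch.tanh(new_cell) * output_gate
        return new_h, new_cell


def python_op_nodes(traced):
    """Return top-level Python-call nodes (assumed kind prim::PythonOp) in a traced graph."""
    return [n for n in traced.graph.nodes() if n.kind() == "prim::PythonOp"]


def trace_and_roundtrip(batch_size=16, input_features=32, state_size=128):
    """Trace the pure-torch LLTM, save it to memory, reload it, and run both."""
    X = torch.randn(batch_size, input_features)
    h = torch.randn(batch_size, state_size)
    C = torch.randn(batch_size, state_size)
    rnn = PurePythonLLTM(input_features, state_size)
    traced = torch.jit.trace(rnn, (X, (h, C)))
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)  # no Python calls in the graph, so this succeeds
    buffer.seek(0)
    loaded = torch.jit.load(buffer)
    return traced, rnn(X, (h, C)), loaded(X, (h, C))
```

Running `trace_and_roundtrip()` should give an empty `python_op_nodes(traced)` list and matching eager and reloaded outputs. Applied to the traced C++-extension model from this issue, `python_op_nodes` would be expected to list the `LLTMFunction` node instead.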