traveller59 / spconv

Spatial Sparse Convolution Library
Apache License 2.0

Unable to transform SparseConv3d with torch.jit.trace. #572

Open sun-sey opened 1 year ago

sun-sey commented 1 year ago

I've seen that some of these issues can be resolved by choosing the appropriate ConvAlgo, but I have tried every ConvAlgo and none of them worked.

Converting SubMConv3d works fine, but torch.jit.trace does not seem to handle the downsampling (SparseConv3d) and upsampling (SparseInverseConv3d) layers. What should I do?
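
For reference, this is roughly how I switched algorithms (a sketch only; the channel sizes and indice_key match the repro network further below, and I am assuming the enum members are named ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, and ConvAlgo.MaskSplitImplicitGemm):

    import spconv.pytorch as spconv

    # Sketch: build the downsampling layer with each available algorithm.
    # Layer settings mirror the repro network shown later in this issue.
    for algo in (spconv.ConvAlgo.Native,
                 spconv.ConvAlgo.MaskImplicitGemm,
                 spconv.ConvAlgo.MaskSplitImplicitGemm):
        down = spconv.SparseConv3d(128, 128, kernel_size=3, stride=2,
                                   padding=1, indice_key='spconv', algo=algo)
        # ... rebuild the rest of the network with this algo and re-run
        # torch.jit.trace; the tracing error below occurs in every case.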

The following error occurs with spconv-cu117 == 2.2.6 and torch == 1.12.1:

File "/home/ssy/Projects/python_spconv/networks.py", line 217, in traced_model = torch.jit.trace(model, [pt_fea, idx, sparse_shape]) File "/home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py", line 759, in trace return trace_module( File "/home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py", line 1001, in trace_module _check_trace( File "/home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "/home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py", line 534, in _check_trace traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace") File "/home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py", line 449, in run_mod_and_filter_tensor_outputs raise TracingCheckError( torch.jit._trace.TracingCheckError: Tracing failed sanity checks! encountered an exception while running the trace with test inputs. Exception: The following operation failed in the TorchScript interpreter. Traceback of TorchScript (most recent call last): /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py(440): forward /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/nn/modules/module.py(1182): _slow_forward /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/nn/modules/module.py(1194): _call_impl /home/ssy/Projects/python_spconv/networks.py(199): forward /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/nn/modules/module.py(1182): _slow_forward /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/nn/modules/module.py(1194): _call_impl /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py(976): trace_module /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace /home/ssy/Projects/python_spconv/networks.py(217): RuntimeError: RuntimeError: /io/build/temp.linux-x86_64-cpython-39/spconv/build/core_cc/src/csrc/sparse/convops/spops/ConvGemmOps/ConvGemmOps_indice_conv.cc(129) nhot_profile > 0 assert faild. this shouldn't happen

    At:
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/ops.py(854): indice_conv
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/functional.py(92): forward
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py(105): decorate_fwd
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py(443): run_mod_and_filter_tensor_outputs
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py(534): _check_trace
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/autograd/grad_mode.py(27): decorate_context
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py(1001): trace_module
      /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace
      /home/ssy/Projects/python_spconv/networks.py(217): <module>

ERROR: Graphs differed across invocations!
Graph diff:

    graph(%self.1 : __torch__.SparseNetworks, %pt_fea : Tensor, %xy_ind : Tensor, %sparse_shape : Tensor):
      %spconv_layer4 : __torch__.spconv.pytorch.conv.SubMConv3d = prim::GetAttr[name="spconv_layer4"]
      %spconv_layer3 : __torch__.spconv.pytorch.conv.SparseInverseConv3d = prim::GetAttr[name="spconv_layer3"]
      %spconv_layer2 : __torch__.spconv.pytorch.conv.SparseConv3d = prim::GetAttr[name="spconv_layer2"]
      %spconv_layer : __torch__.spconv.pytorch.conv.SubMConv3d = prim::GetAttr[name="spconv_layer"]
      %103 : int = prim::Constant[value=1](), scope: __module.spconv_layer # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py:324:0
      %104 : int = prim::Constant[value=128](), scope: __module.spconv_layer # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py:321:0
      %105 : int = prim::Constant[value=3](), scope: __module.spconv_layer # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py:321:0
      %bias.1 : Tensor = prim::GetAttr[name="bias"]
      %weight.1 : Tensor = prim::GetAttr[name="weight"]
      %108 : int[] = prim::ListConstruct(%105, %104), scope: __module.spconv_layer
      %109 : Tensor = aten::view(%weight.1, %108), scope: __module.spconv_layer # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py:321:0
      %features.1 : Tensor = aten::mm(%pt_fea, %109), scope: __module.spconv_layer # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py:319:0
      %features.3 : Tensor = aten::add(%features.1, %bias.1, %103), scope: __module.spconv_layer # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py:324:0
      %112 : int = prim::Constant[value=1](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/ops.py:170:0
      %113 : int = prim::Constant[value=0](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/ops.py:170:0
      %114 : int = prim::Constant[value=4](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:139:0
      %115 : int = prim::Constant[value=351](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:139:0
      %116 : bool = prim::Constant[value=0](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %117 : Device = prim::Constant[value="cuda:0"](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %118 : NoneType = prim::Constant(), scope: __module.spconv_layer2
      %119 : int = prim::Constant[value=3](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %120 : int = prim::Constant[value=-1](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %121 : int = prim::Constant[value=100](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %122 : int = prim::Constant[value=27](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %123 : int = prim::Constant[value=2](), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %bias.3 : Tensor = prim::GetAttr[name="bias"]
      %weight.3 : Tensor = prim::GetAttr[name="weight"]
      %126 : int[] = prim::ListConstruct(%123, %122, %121), scope: __module.spconv_layer2
      %ten.1 : Tensor = aten::full(%126, %120, %119, %118, %117, %116), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:156:0
      %128 : int[] = prim::ListConstruct(%122), scope: __module.spconv_layer2
      %ten.3 : Tensor = aten::zeros(%128, %119, %118, %117, %116), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:124:0
      %130 : int[] = prim::ListConstruct(%115, %114), scope: __module.spconv_layer2
      %ten.13 : Tensor = aten::empty(%130, %119, %118, %117, %116, %118), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/cppcore.py:139:0
      %outids : Tensor = aten::slice(%ten.13, %113, %113, %115, %112), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/ops.py:170:0
      %133 : int = aten::size(%outids, %113), scope: __module.spconv_layer2 # /home/ssy/anaconda3/envs/torch_cpp/lib/python3.9/site-packages/spconv/pytorch/conv.py:445:0
      %134 : Tensor = prim::NumToTensor(%133), scope: __module.spconv_layer2

Here is the code.

import spconv.pytorch as spconv
import torch
import torch.nn as nn

class SparseNetworks(nn.Module):
    def __init__(self):
        super().__init__()
        self.spconv_layer = spconv.SubMConv3d(3, 128, indice_key='sp', kernel_size=1, stride=1, padding=0, algo=spconv.ConvAlgo.Native)
        self.spconv_layer2 = spconv.SparseConv3d(128, 128, kernel_size=3, stride=2, padding=1, indice_key='spconv', algo=spconv.ConvAlgo.Native)
        self.spconv_layer3 = spconv.SparseInverseConv3d(128, 128, kernel_size=3, indice_key='spconv', algo=spconv.ConvAlgo.Native)
        self.spconv_layer4 = spconv.SubMConv3d(128, 128, indice_key='sp2', kernel_size=1, stride=1, padding=0, algo=spconv.ConvAlgo.Native)

    def forward(self, pt_fea, xy_ind, sparse_shape):
        ret = spconv.SparseConvTensor(pt_fea, xy_ind, sparse_shape, batch_size=1)
        x = self.spconv_layer(ret)
        x = self.spconv_layer2(x)
        x = self.spconv_layer3(x)
        x = self.spconv_layer4(x)

        # return ret.features, ret.indices
        return ret.dense().mean()

if __name__ == "__main__":
    model = SparseNetworks().eval().cuda()
    pt_fea = torch.randn((100, 3)).type(torch.FloatTensor).cuda()
    sparse_shape = torch.Tensor([200, 200, 200]).cuda()
    idx = torch.randint(low=0, high=99, size=(100, 3)).type(torch.FloatTensor).cuda()
    batch_size = torch.zeros((100, 1)).type(torch.FloatTensor).cuda()

    idx = torch.cat((batch_size, idx), dim=1).int()
    print('forward result is ', model(pt_fea, idx, sparse_shape))
    traced_model = torch.jit.trace(model, [pt_fea, idx, sparse_shape])
    traced_model.save('my_model.pt')
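
As a diagnostic only (not a fix): the exception above is raised from the trace sanity check, and the graphs apparently differ because values such as the output indice count (the 351 constant in the graph diff) get baked into the traced graph. Passing torch.jit.trace's check_trace argument skips the re-run comparison; the sketch below assumes that is acceptable for debugging, and the output filename is just illustrative:

    # Diagnostic sketch only: skip the trace sanity check that raises
    # TracingCheckError. The traced graph is still specialized to this
    # particular sparsity pattern, so this is not a real workaround.
    traced_model = torch.jit.trace(model, [pt_fea, idx, sparse_shape],
                                   check_trace=False)
    traced_model.save('my_model_unchecked.pt')  # hypothetical output name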
yangqifan913 commented 7 months ago

same problem

First diverging operator: Node diff: