NVIDIA / MinkowskiEngine

Minkowski Engine is an auto-diff neural network library for high-dimensional sparse tensors
https://nvidia.github.io/MinkowskiEngine

MinkowskiEngine v0.5.4 has different kernel_map results compared to v0.4.3 (e.g. HYPER_CROSS) #436

Open Karbo123 opened 2 years ago

Karbo123 commented 2 years ago

Thanks for creating this great work. The recently released MinkowskiEngine v0.5.4 definitely improves the overall usability and brings us some useful functions. However, it would be great if the newest release (v0.5.4) could get the additional improvement detailed below.

Description: the newest version (v0.5.4) seems to have a bug for non-cubic kernel types. Although the exposed API allows the user to specify the kernel type (i.e. HYPER_CUBE, HYPER_CROSS, and CUSTOM), the internal implementation only computes the kernel_map correctly for the cubic kernel; a cross-shaped kernel, for example, generates a wrong kernel_map.
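
For reference, this is what the two built-in region types are supposed to enumerate for kernel_size=3 in D=3 (a minimal sketch of the definitions written by me, not library code): the hypercubic kernel visits all 3**D = 27 neighbor offsets, while the hypercross kernel visits only the center and the +/-1 steps along each single axis, 2*D + 1 = 7 offsets in total.

import itertools

D, half = 3, 1  # three spatial dimensions, kernel_size = 3 -> offsets in {-1, 0, 1}

# HYPER_CUBE: every combination of per-axis offsets, 3**D = 27 in total
cube_offsets = list(itertools.product(range(-half, half + 1), repeat=D))

# HYPER_CROSS: the center plus +/-1 along each single axis, 2*D + 1 = 7 in total
cross_offsets = [(0,) * D]
for axis in range(D):
    for step in (-half, half):
        offset = [0] * D
        offset[axis] = step
        cross_offsets.append(tuple(offset))

print(len(cube_offsets), len(cross_offsets))  # 27 7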

To reproduce: the following code reproduces the issue for ME v0.5.4 (only tested for HYPER_CROSS). NOTE: the code was tested against the current latest commit (f3c5544ce).

import torch
from MinkowskiEngine import CoordinateManager
from MinkowskiEngineBackend._C import RegionType, CoordinateMapType

lin = torch.arange(10)
coords_in = torch.cat([torch.zeros([1000, 1], dtype=torch.long), 
                       torch.stack(torch.meshgrid(lin, lin, lin), dim=-1).reshape(-1, 3)], dim=1).int().cuda()
coords_out = torch.tensor([[0, 5, 5, 5]]).int().cuda()

cm = CoordinateManager(D=3, coordinate_map_type=CoordinateMapType.CUDA)
ck_in, _ = cm.insert_and_map(coordinates=coords_in, tensor_stride=1, string_id="input")
ck_out, _ = cm.insert_and_map(coordinates=coords_out, tensor_stride=1, string_id="output")

kernel_kwargs = dict(kernel_size=3, region_type=RegionType.HYPER_CROSS)
ind = torch.cat(list(cm.kernel_map(ck_in, ck_out, **kernel_kwargs).values()), dim=1)[0]
diff = coords_in[ind.long()] - coords_out
print(diff)
print(f"center = {diff.float().mean()} (should be zero)")

the above code prints:

tensor([[ 0, -1,  1, -1],
        [ 0,  1,  0, -1],
        [ 0,  0, -1, -1],
        [ 0, -1, -1, -1],
        [ 0,  1, -1, -1],
        [ 0, -1,  0, -1],
        [ 0,  0,  0, -1]], device='cuda:0', dtype=torch.int32)
center = -0.3571428656578064 (should be zero)
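
Comparing these offsets against the seven offsets a 3-D HYPER_CROSS with kernel_size=3 should produce, only one of them, (0, 0, -1), actually belongs to the cross; the others look like cube offsets restricted to the z = -1 plane. A small sanity check of my own (the returned set is copied from the output above with the batch column dropped; the expected set is written out by hand):

# spatial offsets printed by v0.5.4 above (batch column dropped)
returned = {(-1, 1, -1), (1, 0, -1), (0, -1, -1), (-1, -1, -1),
            (1, -1, -1), (-1, 0, -1), (0, 0, -1)}

# the seven offsets a 3-D HYPER_CROSS with kernel_size=3 should produce
expected = {(0, 0, 0),
            (1, 0, 0), (-1, 0, 0),
            (0, 1, 0), (0, -1, 0),
            (0, 0, 1), (0, 0, -1)}

print(returned & expected)   # {(0, 0, -1)} -- only one offset overlaps
print(expected - returned)   # the six cross offsets that are missing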

For comparison, I also compiled an older ME version; the following code is written for ME v0.4.3 (only tested for HYPER_CROSS):

import torch
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor

lin = torch.arange(10)
coords_in = torch.cat([torch.zeros([1000, 1], dtype=torch.long), 
                       torch.stack(torch.meshgrid(lin, lin, lin), dim=-1).reshape(-1, 3)], dim=1).float().cuda()
coords_out = torch.tensor([[0, 5, 5, 5]]).float().cuda()

sin = SparseTensor(feats=torch.randn(1000, 16), coords=coords_in)
sout = SparseTensor(feats=torch.randn(1, 32), coords=coords_out, force_creation=True, coords_manager=sin.coords_man)
ck_in = sin.coords_key
ck_out = sout.coords_key
cm = sin.coords_man

kernel_kwargs = dict(kernel_size=3, region_type=1)
ind = torch.tensor(cm.get_kernel_map(ck_in, ck_out, **kernel_kwargs))[0]
diff = coords_in[ind.long()] - coords_out
print(diff)
print(f"center = {diff.float().mean()} (should be zero)")

however, the older version (v0.4.3) prints the expected result:

tensor([[ 0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.],
        [ 0., -1.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0., -1.,  0.],
        [ 0.,  0.,  0.,  1.],
        [ 0.,  0.,  0., -1.]], device='cuda:0')
center = 0.0 (should be zero)

This shows that the version update may have changed this behavior unexpectedly.

Expected behavior: the latest version (v0.5.4) has different kernel_map behavior compared to the old version (v0.4.3). The message printed by v0.5.4 should also have center = 0.0.

Possible solutions: it seems that the coordinate_at function defined here only implements cubic kernels. Maybe support for the other kernel types can be added in a similar way, following the previous implementation, but I am not sure.
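
I am not familiar with the internals, but in terms of logic the missing piece appears to be a per-kernel-index offset lookup for the non-cubic region types. Below is a rough Python transcription of that logic as I understand it; the function name, signature, and offset ordering are my own and purely illustrative, and the real fix would of course go into the C++/CUDA side where coordinate_at lives:

def kernel_offset(kernel_index, kernel_size, D, region_type):
    """Map a flat kernel index to its D-dimensional coordinate offset."""
    half = kernel_size // 2
    if region_type == "HYPER_CUBE":
        # decode the index as a base-kernel_size number, one digit per axis
        offset = []
        for _ in range(D):
            offset.append(kernel_index % kernel_size - half)
            kernel_index //= kernel_size
        return tuple(offset)
    if region_type == "HYPER_CROSS":
        # index 0 is the center; each remaining index steps along a single axis
        if kernel_index == 0:
            return (0,) * D
        axis, j = divmod(kernel_index - 1, kernel_size - 1)
        step = j - half
        if step >= 0:
            step += 1  # skip 0, already covered by the center
        offset = [0] * D
        offset[axis] = step
        return tuple(offset)
    raise NotImplementedError(region_type)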

It would be great if some effort could be put into this. If there is anything wrong with my report, please let me know. Looking forward to this bug fix.