NVIDIA / MinkowskiEngine

Minkowski Engine is an auto-diff neural network library for high-dimensional sparse tensors
https://nvidia.github.io/MinkowskiEngine

merge_sort: failed to synchronize: cudaErrorIllegalAddress with TensorField during training on some ScanNet scans #395

Open alexsr opened 3 years ago

alexsr commented 3 years ago

Hello, first and foremost, thank you for creating this library! I may have found a bug, but I also found a workaround / solution and thought I'd share it here.

TL;DR: When using TensorField during training, back-propagation through the inverse mapping sometimes fails. To get around this issue, use SparseTensor instead and only apply the inverse mapping to predictions, using the maps returned by sparse_quantize.


Describe the bug

I am using MinkowskiEngine 0.5.4 with PyTorch 1.9.0 and CUDA 11.1 in a personal project on the ScanNet dataset. While training, I have run into the following error a few times now:

RuntimeError: merge_sort: failed to synchronize: cudaErrorIllegalAddress: an illegal memory access was encountered

This error did not occur while running predictions. I also looked at issues #283 and #299. However, the current README states that the problems that might have caused those errors are resolved in the current version of the library. Running my script under torch.autograd.detect_anomaly(), I was able to find the location that caused the error:

[W python_anomaly_mode.cpp:104] Warning: Error detected in IndexBackward. Traceback of forward call that caused the error:
[...]
  File "/root/miniconda3/envs/myproject/lib/python3.8/site-packages/MinkowskiEngine/MinkowskiSparseTensor.py", line 599, in slice
    self.F[X.inverse_mapping(self.coordinate_map_key).long()],
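For readers who have not used anomaly mode, here is a minimal CPU-only sketch of how it is enabled. The toy tensor and the NaN it produces are assumptions chosen purely for illustration; they are unrelated to the CUDA failure reported in this issue. Anomaly mode also prints the forward traceback of the failing operation as a warning, which is where an excerpt like the one above comes from.

```python
import torch

# Toy input: sqrt of a negative number is NaN, and so is its gradient.
x = torch.tensor([-1.0, 4.0], requires_grad=True)

message = None
try:
    # Inside this context, autograd records forward tracebacks and
    # raises if a backward function returns NaN values.
    with torch.autograd.detect_anomaly():
        loss = torch.sqrt(x).sum()
        loss.backward()
except RuntimeError as err:
    message = str(err)

# The message names the backward function that produced the NaN.
print(message)
```

Anomaly mode slows training down noticeably, so it is usually switched on only while hunting for a failure like this one.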

To Reproduce

Here are the relevant parts of my code: collation, creation of the TensorField, the forward method of my model (MinkUNet34C), and the loss computation. My framework is currently quite complicated, so I cannot really share the complete code, but this bug should probably also occur when using the ScanNet demo for training. This works perfectly fine for most ScanNet scans. However, there are a few, e.g. scene_0029_00, where the crash occurs. It seems that there is an issue with backpropagation through IndexBackward.

labels = torch.from_numpy(np.concatenate(labels, 0).astype(dtype=np.int32)).int()
feats = normalize_color(torch.from_numpy(np.concatenate(feats, 0)).float())
coords = ME.utils.batched_coordinates([c for c in coords], dtype=torch.float32)
[...]
minknet_input = ME.TensorField(coordinates=coords, features=feats,
    quantization_mode=self.quantization_mode,
    minkowski_algorithm=self.minkowski_algorithm,
    device=self.device)
[...]
def forward(self, x):
   x_sparse = x.sparse()
   o_sparse = self.model(x_sparse)
   out = o_sparse.slice(x)
   return out
[...]
_, pred = out.F.max(1)
pred_batched = [pred[row_idx] for row_idx in out.decomposition_permutations if row_idx.shape[0] != 0]
labels_out = out.F.float()
loss = self.criterion(labels_out, labels)

I was able to get around this error by using SparseTensor, getting the mappings and inverse mappings from sparse_quantize, and doing the inverse mapping myself whenever I needed the prediction output.

These snippets are meant to replace the equivalent snippets in the code sample above:

coords_mapped, mappings, inv_mappings = zip(*[ME.utils.sparse_quantize(coordinates=c, quantization_size=self.quantization_size,
                                             return_index=True, return_inverse=True) for c in coords])
feats_mapped = [c[mappings[i]] for i, c in enumerate(feats)]
labels_mapped = [c[mappings[i]] for i, c in enumerate(labels)]

labels = torch.from_numpy(np.concatenate(labels_mapped, 0).astype(dtype=np.int32)).int()
feats = normalize_color(torch.from_numpy(np.concatenate(feats_mapped, 0)).float())
coords = ME.utils.batched_coordinates([c for c in coords_mapped], dtype=torch.int32)
[...]
minknet_input = ME.SparseTensor(feats.float(), coords)
[...]
def forward(self, x):
   return self.model(x)
[...]
_, pred = out.F.max(1)
pred_batched = [pred[row_idx] for row_idx in out.decomposition_permutations if row_idx.shape[0] != 0]
# inv_mappings holds one inverse map per sample, so index it per sample
pred_batched = [p[inv_mappings[i]] for i, p in enumerate(pred_batched)]
labels_out = out.F.float()
loss = self.criterion(labels_out, labels)
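
To make the roles of the two maps in this workaround concrete, here is a small pure-Python sketch. The function quantize_with_maps, the sample points, and the labels are hypothetical stand-ins; this is not MinkowskiEngine's implementation, only an illustration of what an index map (one original point per unique voxel) and an inverse map (the voxel of every original point) are used for.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]
Voxel = Tuple[int, int, int]


def quantize_with_maps(points: Sequence[Point], voxel_size: float):
    """Toy voxel quantization that also returns index and inverse maps."""
    first_seen = {}                  # voxel -> position in unique_voxels
    unique_voxels: List[Voxel] = []
    index_map: List[int] = []        # unique voxel -> one original point
    inverse_map: List[int] = []      # original point -> its unique voxel
    for i, p in enumerate(points):
        voxel = tuple(math.floor(c / voxel_size) for c in p)
        if voxel not in first_seen:
            first_seen[voxel] = len(unique_voxels)
            unique_voxels.append(voxel)
            index_map.append(i)
        inverse_map.append(first_seen[voxel])
    return unique_voxels, index_map, inverse_map


# Four points that fall into two voxels of size 0.05.
points = [(0.01, 0.02, 0.00), (0.03, 0.01, 0.04),
          (0.11, 0.02, 0.00), (0.12, 0.03, 0.01)]
point_labels = [3, 3, 7, 7]

voxels, index_map, inverse_map = quantize_with_maps(points, voxel_size=0.05)

# Training side: keep one label per voxel by selecting with the index map.
voxel_labels = [point_labels[i] for i in index_map]

# Prediction side: pretend the network predicts each voxel label correctly,
# then copy voxel predictions back to every original point.
voxel_pred = list(voxel_labels)
point_pred = [voxel_pred[j] for j in inverse_map]
```

Because the inverse map is applied only to the detached predictions, no gradient has to flow through the indexing step, which is the part that failed in the TensorField version.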

==========System==========
Linux-5.4.0-81-generic-x86_64-with-glibc2.17
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.2 LTS"
3.8.11 (default, Aug  3 2021, 15:09:35) 
[GCC 7.5.0]
==========Pytorch==========
1.9.0
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 470.57.02
CUDA Version 11.4
VBIOS Version 86.07.59.00.70
Image Version N/A
GSP Firmware Version N/A
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Wed_Jul_14_19:41:19_PDT_2021
Cuda compilation tools, release 11.4, V11.4.100
Build cuda_11.4.r11.4/compiler.30188945_0
==========CC==========
/usr/bin/c++
c++ (Ubuntu 7.5.0-6ubuntu2) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.4
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 11010
CUDART version MinkowskiEngine is compiled: 11010

Additional context

As an aside, the example in the Training Tutorial shows this code snippet:

mapping = ME.utils.sparse_quantize(
            coords=coords,
            return_index=True)

when it should be:

coords_mapped, mapping = ME.utils.sparse_quantize(
            coordinates=coords,
            return_index=True)

The example in the function documentation is also out of date: it should either return unique_coords as well or pass return_maps_only=True as an argument:

        Example::
           >>> unique_map, inverse_map = sparse_quantize(discrete_coords, return_index=True, return_inverse=True)

wtiandong commented 3 years ago

Same here. When I run examples/classification_modelnet40.py, I get the same error at the same location: self.F[X.inverse_mapping(self.coordinate_map_key).long()]

My environment: CUDA 11.1, PyTorch 1.9, Ubuntu 18.04, MinkowskiEngine 0.5.4.

To reproduce, just run examples/classification_modelnet40.py without changing anything. It fails on the 307th training iteration.
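
Since the failure shows up only at particular iterations or scans, it can help to record exactly which batch was being processed when it happened. The following is a minimal sketch under stated assumptions: find_first_failure, fake_step, and the scene names are hypothetical, and the fake step merely simulates an error. In a real CUDA run, errors are reported asynchronously, so setting the environment variable CUDA_LAUNCH_BLOCKING=1 makes the attribution to a batch more reliable.

```python
from typing import Callable, Iterable, Optional, Tuple


def find_first_failure(
    batches: Iterable[Tuple[str, object]],
    step: Callable[[object], None],
) -> Optional[Tuple[int, str, str]]:
    """Run step on each batch and report the first one that raises."""
    for iteration, (scene_id, data) in enumerate(batches):
        try:
            step(data)
        except RuntimeError as err:
            # Record where it failed so the batch can be replayed in isolation.
            return iteration, scene_id, str(err)
    return None


def fake_step(data: object) -> None:
    # Stand-in for forward + backward; fails on one specific input.
    if data == "bad":
        raise RuntimeError("simulated failure")


batches = [("scene_a", "ok"), ("scene_b", "ok"), ("scene_c", "bad")]
failure = find_first_failure(batches, fake_step)
print(failure)
```

After an illegal memory access the CUDA context is generally not usable any more, so the point of such a wrapper is to log the offending sample and restart, not to continue training.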

gitouni commented 2 years ago

pred_batched = [pred[row_idx] for row_idx in out.decomposition_permutations if row_idx.shape[0] != 0]
pred_batched = [p[inv_mappings] for p in pred_batched]

You can change the --network argument from the default minkfcnn to minksplatfcnn to avoid this error. As the author said, ME.TensorField often causes backward issues on the GPU, but I am new to this project and have no idea how to modify the minkfcnn network so that it accepts ME.SparseTensor input.

To illustrate my issue, here is the forward method of minkfcnn in classification_modelnet40.py:

def forward(self, x: ME.TensorField):
        x = self.mlp1(x)
        y = x.sparse()

        y = self.conv1(y)
        y1 = self.pool(y)

        y = self.conv2(y1)
        y2 = self.pool(y)

        y = self.conv3(y2)
        y3 = self.pool(y)

        y = self.conv4(y3)
        y4 = self.pool(y)

        x1 = y1.slice(x)
        x2 = y2.slice(x)
        x3 = y3.slice(x)
        x4 = y4.slice(x)

        x = ME.cat(x1, x2, x3, x4)

        y = self.conv5(x.sparse())
        x1 = self.global_max_pool(y)
        x2 = self.global_avg_pool(y)

        return self.final(ME.cat(x1, x2)).F

If the TensorField is changed to a SparseTensor, the ME.cat operation no longer works. However, using slice(x) with the inverse mapping to convert an ME.SparseTensor into an ME.TensorField does not work during training. So I have no idea how to fix this error around the ME.cat operation.

gitouni commented 2 years ago

coords = ME.utils.batched_coordinates([c for c in coords_mapped], dtype=torch.int32)

Is the snippet that computes pred_batched necessary for the backward pass? (It is not accessed in prediction. Does it affect the backward process?)

alexsr commented 2 years ago

I used pred_batched so that I had both the predictions and the inputs as batches. That way I could do visualization or whatever on the batches. The code snippet is just meant to show how to recover batches if necessary.

The outputs used for the loss (backward step) don't have to be in batch format.

Remember, the code snippet I posted here is just a quick fix I came up with for the issues I faced while trying to use TensorField which did not work for me.
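
As a rough illustration of what recovering batches means here, the sketch below groups row indices by batch index and then gathers a flat prediction list into per-sample lists. The helper rows_per_batch and the sample values are hypothetical, and plain Python lists stand in for tensors; it only mirrors the idea behind splitting a batched output into samples, not MinkowskiEngine's implementation.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def rows_per_batch(batch_indices: Sequence[int]) -> List[List[int]]:
    """Return, for each batch index in ascending order, the rows that belong to it."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for row, batch in enumerate(batch_indices):
        groups[batch].append(row)
    return [groups[batch] for batch in sorted(groups)]


# Batch index of every output row, e.g. the first column of batched coordinates.
batch_indices = [0, 0, 1, 0, 1]
pred = [5, 5, 2, 4, 2]  # flat per-row predictions

permutations = rows_per_batch(batch_indices)
pred_batched = [[pred[row] for row in rows] for rows in permutations]
```

The loss itself can be computed on the flat rows directly; the grouping is only needed when each sample is inspected or visualized on its own.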

KarelZhang commented 11 months ago

I fixed this error by using PyTorch 1.10.0 with CUDA 11.3.

barzanisar commented 11 months ago

I have the same error

Eaphan commented 8 months ago

I have the same error too when I use MinkowskiEngine at tag v0.5.4. After I checked out commit 02fc608bea4c0549b0a7b00ca1bf15dee4a0b228 and reinstalled MinkowskiEngine, the error disappeared.

bujiebuhuo commented 3 months ago

Could you release another version? The 0.5.4 release is misleading, and this error cost me a lot of time. I assumed the latest release would be relatively stable, so I chose the latest release, Minkowski Engine 0.5.4.