alexandor91 / se3-equi-graph-registration

Equi-GSPR: Equivariant SE(3) Graph Network Model for Sparse Point Cloud Registration, ECCV Paper Code
MIT License

Mistakes in training #5

Open shangxiaqius opened 4 weeks ago

shangxiaqius commented 4 weeks ago

Hello, first of all, thank you very much for this research work! However, when I try to train with the 3DMatch dataset you provided, I get an error:

(torch) yy@yy:~/se3-equi-graph-registration/src$ python train_eval_egnn.py --mode train --base_dir /media/yy/Elements/registration_data/3DMatch_FCGF_Feature --num_epochs 50 --batch_size 2

Epoch 1/50
Traceback (most recent call last):
  File "train_eval_egnn.py", line 1201, in <module>
    train_model(cross_attention_model, train_loader, val_loader, num_epochs=num_epochs, \
  File "train_eval_egnn.py", line 969, in train_model
    train_loss = train_one_epoch(model, train_loader, optimizer, device, epoch, writer, use_pointnet, log_interval, beta)
  File "train_eval_egnn.py", line 696, in train_one_epoch
    graph_idx_0 = knn_graph(xyz_0, k=k, loop=False)
  File "/home/yy/anaconda3/envs/torch/lib/python3.8/site-packages/torch_cluster/knn.py", line 132, in knn_graph
    edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
  File "/home/yy/anaconda3/envs/torch/lib/python3.8/site-packages/torch_cluster/knn.py", line 81, in knn
    return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
  File "/home/yy/.local/lib/python3.8/site-packages/torch/_ops.py", line 854, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: x.dim() == 2 INTERNAL ASSERT FAILED at "csrc/cuda/knn_cuda.cu":93, please report a bug to PyTorch. Input mismatch

My PyTorch version is 2.3.0 and my CUDA version is 12.1. Could the error be caused by my versions being different from yours? Thank you very much.

alexandor91 commented 4 weeks ago

Are you using a batch size larger than one? The current knn_graph call has a bug with batch sizes > 1, so batch sizes larger than one are not supported yet.

If not, try the cosine distance option: if cosine is not explicitly provided, try specifying it (cosine=True or cosine=False) to avoid potential issues: "edge_index = knn(x, x, k, batch_x=batch, batch_y=batch, cosine=False)".

Alternatively, update PyTorch Geometric and its dependencies, since this could also be a version compatibility issue. Make sure your torch, torch-cluster, and torch-geometric libraries are up to date. You can update them using: "pip install --upgrade torch torch-cluster torch-geometric".

We will commit a fix for the batch-size bug by the end of this month, as we found that larger batches help accumulate gradients for more stable training. The recommended PyTorch Geometric (PyG) version is 2.4.0; the knn API may have been adapted or remapped in newer versions.
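The assertion in the traceback (x.dim() == 2) checks that the points handed to knn_graph form a 2-D [N, 3] tensor. Assuming the loader stacks a batch into a dense xyz_0 of shape [B, N, 3] when --batch_size is 2 (this has not been verified against the repository code), a common pattern is to flatten the clouds and pass a separate batch vector instead. The sketch below shows that layout. flatten_batched_points and knn_graph_reference are hypothetical helpers; the second is a pure-PyTorch stand-in intended to mirror torch_cluster's knn_graph with loop=False (edge_index[0] = neighbor, edge_index[1] = query), not the repository's code.

```python
import torch


def flatten_batched_points(xyz: torch.Tensor):
    """Hypothetical helper: turn dense [B, N, 3] points into the 2-D layout
    torch_cluster expects, plus a batch vector saying which cloud each point
    belongs to."""
    if xyz.dim() == 2:  # already [N, 3]: treat it as a single cloud
        xyz = xyz.unsqueeze(0)
    B, N, C = xyz.shape
    points = xyz.reshape(B * N, C)                                  # [B*N, 3]
    batch = torch.arange(B, device=xyz.device).repeat_interleave(N)  # [B*N]
    return points, batch


def knn_graph_reference(points: torch.Tensor, batch: torch.Tensor, k: int):
    """Hypothetical pure-PyTorch kNN graph without self-loops.

    Neighbors are only searched inside the same cloud (same batch id).
    Returns edge_index of shape [2, E] with row 0 = neighbor index and
    row 1 = query index, both as global indices into `points`.
    """
    edges = []
    for b in torch.unique(batch):
        idx = (batch == b).nonzero(as_tuple=True)[0]  # global indices of cloud b
        pts = points[idx]
        dist = torch.cdist(pts, pts)                  # [n, n] Euclidean distances
        dist.fill_diagonal_(float("inf"))             # a point is not its own neighbor
        kk = min(k, pts.shape[0] - 1)
        nbr = dist.topk(kk, dim=1, largest=False).indices  # [n, kk] local indices
        query = torch.arange(pts.shape[0], device=points.device).repeat_interleave(kk)
        edges.append(torch.stack([idx[nbr.reshape(-1)], idx[query]]))
    return torch.cat(edges, dim=1)
```

For example, points, batch = flatten_batched_points(torch.rand(2, 100, 3)) produces a [200, 3] tensor and a [200] batch vector. With torch-cluster installed, those two tensors can be passed to the library directly as knn_graph(points, k=8, batch=batch, loop=False); the reference version above only makes the expected shapes explicit and costs O(n^2) memory per cloud.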

shangxiaqius commented 3 weeks ago

ok! Thank you very much!

shangxiaqius commented 3 weeks ago

Hello, sorry to bother you with another question. The following warnings appeared during training. Can I safely ignore them, and where can I change this ratio? Thank you very much.

Warning: Not enough sample points for the fixed number, sampling with repetitions.
Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating samplinf will be used!
(the line above was printed 7 times)

alexandor91 commented 3 weeks ago

These warnings are only there to help with debugging and are normal; you can disable them in the dataloader file. Some scan pairs do not contain enough point pairs to reach the fixed sample count, so point pairs are sampled with repetition.
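The dataloader itself is not shown in this thread, so the following is only a hypothetical sketch of the behavior the warnings describe: drawing a fixed number of point pairs at a 0.60/0.40 positive/negative ratio and falling back to sampling with replacement when a class is too small. The function name sample_pos_neg and its pos_ratio and verbose parameters are illustrative assumptions, not the repository's API.

```python
import numpy as np


def sample_pos_neg(pos_idx, neg_idx, num_samples, pos_ratio=0.6, verbose=True, rng=None):
    """Hypothetical sampler: pick num_samples indices, about pos_ratio of them positive.

    Each class is sampled without replacement when it has enough members and
    with replacement (repetition) otherwise.
    """
    rng = np.random.default_rng() if rng is None else rng
    pos_idx = np.asarray(pos_idx)
    neg_idx = np.asarray(neg_idx)
    n_pos = int(round(num_samples * pos_ratio))  # e.g. 6 of 10 at pos_ratio=0.6
    n_neg = num_samples - n_pos
    if (n_pos > 0 and len(pos_idx) == 0) or (n_neg > 0 and len(neg_idx) == 0):
        raise ValueError("cannot sample from an empty positive or negative set")

    pos_short = len(pos_idx) < n_pos
    neg_short = len(neg_idx) < n_neg
    if (pos_short or neg_short) and verbose:
        # Debugging message only; pass verbose=False to silence it.
        print(f"Not enough positive or negative points to satisfy the "
              f"{pos_ratio:.2f}-{1 - pos_ratio:.2f} ratio; sampling with repetition.")

    pos = rng.choice(pos_idx, size=n_pos, replace=pos_short)
    neg = rng.choice(neg_idx, size=n_neg, replace=neg_short)
    return np.concatenate([pos, neg])
```

Keeping the ratio and the verbosity as parameters rather than hard-coded constants is one way to address both questions in the thread: the ratio is changed at the call site, and the message can be turned off without removing the repetition fallback itself.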