drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"
MIT License

RANSAC Error on Custom Dataset #32

Closed: mbendjilali closed this issue 11 months ago

mbendjilali commented 11 months ago

Hello, I've followed the documentation guidelines to train SPT on my custom dataset, but training never gets going: RANSAC raises a ValueError while preprocessing one of the point clouds.

Error executing job with overrides: ['experiment=HAG', 'ckpt_path=../HAG.ckpt']
Traceback (most recent call last):
  File "src/train.py", line 139, in main
    metric_dict, _ = train(cfg)
  File "/data/Moussa/superpoint_transformer/src/utils/utils.py", line 48, in wrap
    raise ex
  File "/data/Moussa/superpoint_transformer/src/utils/utils.py", line 45, in wrap
    metric_dict, object_dict = task_func(cfg=cfg)
  File "src/train.py", line 114, in train
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 932, in _run
    self._data_connector.prepare_data()
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 94, in prepare_data
    call._call_lightning_datamodule_hook(trainer, "prepare_data")
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 166, in _call_lightning_datamodule_hook
    return fn(*args, **kwargs)
  File "/data/Moussa/superpoint_transformer/src/datamodules/base.py", line 144, in prepare_data
    self.dataset_class(
  File "/data/Moussa/superpoint_transformer/src/datasets/base.py", line 200, in __init__
    super().__init__(root, transform, pre_transform, pre_filter)
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/torch_geometric/data/in_memory_dataset.py", line 57, in __init__
    super().__init__(root, transform, pre_transform, pre_filter, log)
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/torch_geometric/data/dataset.py", line 97, in __init__
    self._process()
  File "/data/Moussa/superpoint_transformer/src/datasets/base.py", line 511, in _process
    self.process()
  File "/data/Moussa/superpoint_transformer/src/datasets/base.py", line 546, in process
    self._process_single_cloud(p)
  File "/data/Moussa/superpoint_transformer/src/datasets/base.py", line 577, in _process_single_cloud
    nag = self.pre_transform(data)
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/torch_geometric/transforms/compose.py", line 24, in __call__
    data = transform(data)
  File "/data/Moussa/superpoint_transformer/src/transforms/transforms.py", line 23, in __call__
    return self._process(x)
  File "/data/Moussa/superpoint_transformer/src/transforms/point.py", line 223, in _process
    ransac = RANSACRegressor(random_state=0, residual_threshold=1e-3).fit(
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/sklearn/base.py", line 1152, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "/home/moussa/miniconda3/envs/spt/lib/python3.8/site-packages/sklearn/linear_model/_ransac.py", line 358, in fit
    raise ValueError(
ValueError: `min_samples` may not be larger than number of samples: n_samples = 2.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Could you help me with this issue?

mbendjilali commented 11 months ago

From what I understand, the GroundElevation transform selects the points below a certain threshold and fits a plane to them to model the ground. The error arises because, in one of the point clouds, only two points satisfy this threshold condition, which is fewer than the three points needed to fit a plane.
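The error message itself comes from scikit-learn, not from SPT. When `min_samples` is left at its default, `RANSACRegressor` (with its default `LinearRegression` base estimator) uses `min_samples = X.shape[1] + 1`. Assuming the ground plane is fit as z as a function of (x, y), X has two features, so three samples are required and two points are rejected before any fitting happens. A minimal reproduction with made-up coordinates (the helper `try_fit_ground` is hypothetical):

```python
import numpy as np
from sklearn.linear_model import RANSACRegressor


def try_fit_ground(xy: np.ndarray, z: np.ndarray) -> str:
    """Fit z = a*x + b*y + c with RANSAC and report the outcome."""
    try:
        # Same constructor arguments as in the traceback above.
        RANSACRegressor(random_state=0, residual_threshold=1e-3).fit(xy, z)
        return "ok"
    except ValueError as err:
        return str(err)


# Only two candidate ground points: fewer than X.shape[1] + 1 = 3.
xy_two = np.array([[0.0, 0.0], [1.0, 0.0]])
z_two = np.array([0.0, 0.0])
print(try_fit_ground(xy_two, z_two))

# Four coplanar points on a gently tilted plane: enough samples.
xy_four = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
z_four = 0.1 * xy_four[:, 0] + 0.2 * xy_four[:, 1]
print(try_fit_ground(xy_four, z_four))
```

The first call should print the same `min_samples` message as the traceback; the second prints `ok`.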

mbendjilali commented 11 months ago

By multiplying all coordinates of the point cloud by a constant factor, I was able to work around this issue.
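One plausible reason the workaround helps: a threshold expressed in coordinate units selects a different number of points once the coordinates are rescaled. The sketch below assumes, for illustration only, that candidate ground points are those lying within a fixed `threshold` above the lowest point; the function `count_low_points`, the threshold value and the coordinates are all hypothetical. Under this assumed rule, only a factor below 1 admits more points.

```python
import numpy as np


def count_low_points(pos: np.ndarray, threshold: float) -> int:
    """Count points whose height above the lowest point is below `threshold`.

    Hypothetical selection rule, assumed for illustration only.
    """
    z = pos[:, 2]
    return int(np.count_nonzero(z - z.min() < threshold))


# Made-up cloud: x and y are arbitrary, only z matters for the selection.
pos = np.array([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 2.0],
    [0.0, 1.0, 7.0],
    [1.0, 1.0, 12.0],
    [2.0, 2.0, 30.0],
])
threshold = 5.0  # hypothetical value, in the same units as the coordinates

for factor in (1.0, 0.5, 0.1):
    n = count_low_points(pos * factor, threshold)
    print(f"factor={factor}: {n} candidate ground points")
# factor=1.0: 2 candidate ground points
# factor=0.5: 3 candidate ground points
# factor=0.1: 5 candidate ground points
```

Note that rescaling the cloud also changes every other distance-based parameter and feature in the pipeline, so it likely behaves more like a workaround than a fix.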

drprojects commented 11 months ago

Hi, yes, RANSAC is used to find the ground/floor as a roughly planar structure in order to compute the point elevation feature. This is a heuristic, and its parameterization depends on your dataset (see the indoor settings in the S3DIS configs and the outdoor settings in the KITTI-360 and DALES configs). I am not claiming that these settings are universal; you will need to adapt them to your dataset. It is also possible that your data contains no floor or ground at all, in which case you may simply remove "elevation" from partition_hf and point_hf in your config.
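For a custom dataset where only some clouds lack enough ground points, another option is to guard the plane fit. The sketch below is a hypothetical standalone function, not SPT's `GroundElevation` transform: it assumes candidate ground points are those within `threshold` of the lowest point, fits z as a plane in (x, y) with the `RANSACRegressor` arguments seen in the traceback, and computes elevation as the vertical offset from that plane. The fallback to a flat ground at the lowest point, used when fewer than three candidates remain, is an assumption of this sketch.

```python
import numpy as np
from sklearn.linear_model import RANSACRegressor


def ground_elevation(pos: np.ndarray, threshold: float) -> np.ndarray:
    """Hypothetical per-point elevation above a RANSAC-fitted ground plane."""
    z = pos[:, 2]
    # Candidate ground points: within `threshold` of the lowest point (assumed rule).
    low = z - z.min() < threshold

    # With 2 features (x, y), RANSAC's default min_samples is 3.
    if np.count_nonzero(low) < 3:
        # Fallback (assumption, not from SPT): flat ground at the lowest point.
        return z - z.min()

    ransac = RANSACRegressor(random_state=0, residual_threshold=1e-3)
    ransac.fit(pos[low, :2], z[low])
    # Elevation is the vertical offset from the fitted plane.
    return z - ransac.predict(pos[:, :2])


# Usage: a flat 5x5 ground grid at z=0 plus one point 3 units above it.
gx, gy = np.meshgrid(np.arange(5.0), np.arange(5.0))
ground = np.column_stack([gx.ravel(), gy.ravel(), np.zeros(25)])
cloud = np.vstack([ground, [[2.0, 2.0, 3.0]]])
h = ground_elevation(cloud, threshold=1.0)
print(h[-1])  # elevation of the raised point, close to 3.0

# Two-point cloud: the guard avoids the ValueError from the traceback.
print(ground_elevation(np.array([[0.0, 0.0, 1.0], [1.0, 0.0, 3.0]]), threshold=5.0))
```

Whether such a fallback yields a meaningful elevation feature depends on the data; removing "elevation" from the config, as suggested above, remains the simpler route when there is no ground.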