drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"
MIT License
560 stars 72 forks

Prediction without sampling #34

Closed Yarroudh closed 11 months ago

Yarroudh commented 11 months ago

Hello, thanks for the great work. I trained a model on my own data. The results are good and I could run inference on the test dataset. However, I noticed that inference cannot be run on data without sub-sampling, which reduces the number of points. Is there any possibility to run inference without sampling and keep the original data points?

In my .yaml file I noticed there is a GridSampling operation:

# Preprocessing
pre_transform:
    - transform: DataTo
      params:
        device: 'cuda'
    - transform: GridSampling3D
      params:
        size: ${datamodule.voxel}
        hist_key: 'y'
        hist_size: ${eval:'${datamodule.num_classes} + 1'}

If I delete this transformation, I get this error:


Traceback (most recent call last):
  File "src/train.py", line 139, in main
    metric_dict, _ = train(cfg)
  File "/home/anass/superpoint_transformer/src/utils/utils.py", line 48, in wrap
    raise ex
  File "/home/anass/superpoint_transformer/src/utils/utils.py", line 45, in wrap
    metric_dict, object_dict = task_func(cfg=cfg)
  File "src/train.py", line 114, in train
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 932, in _run
    self._data_connector.prepare_data()
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 94, in prepare_data
    call._call_lightning_datamodule_hook(trainer, "prepare_data")
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 166, in _call_lightning_datamodule_hook
    return fn(*args, **kwargs)
  File "/home/anass/superpoint_transformer/src/datamodules/base.py", line 144, in prepare_data
    self.dataset_class(
  File "/home/anass/superpoint_transformer/src/datasets/base.py", line 193, in __init__
    super().__init__(root, transform, pre_transform, pre_filter)
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/torch_geometric/data/in_memory_dataset.py", line 57, in __init__
    super().__init__(root, transform, pre_transform, pre_filter, log)
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/torch_geometric/data/dataset.py", line 97, in __init__
    self._process()
  File "/home/anass/superpoint_transformer/src/datasets/base.py", line 493, in _process
    self.process()
  File "/home/anass/superpoint_transformer/src/datasets/base.py", line 528, in process
    self._process_single_cloud(p)
  File "/home/anass/superpoint_transformer/src/datasets/base.py", line 559, in _process_single_cloud
    nag = self.pre_transform(data)
  File "/home/anass/miniconda3/envs/spt/lib/python3.8/site-packages/torch_geometric/transforms/compose.py", line 24, in __call__
    data = transform(data)
  File "/home/anass/superpoint_transformer/src/transforms/transforms.py", line 23, in __call__
    return self._process(x)
  File "/home/anass/superpoint_transformer/src/transforms/partition.py", line 195, in _process
    assert d1.y.dim() == 2, \
AssertionError: Expected Data.y to hold `(num_nodes, num_classes)` histograms, not single labels

It appears that the Data.y tensor doesn't have the expected shape.

drprojects commented 11 months ago

Hi, thanks for your interest in the project.

Why use grid sampling?

Voxelization is required to mitigate memory use and smooth out variations in point density. Removing it may blow up your memory use on very dense point clouds (e.g. S3DIS). As a rule of thumb, you want your voxel resolution to be about twice as fine as the characteristic dimension of your smallest object of interest (think Nyquist-Shannon sampling theorem). Rather than removing the voxelization altogether, I would recommend reducing the voxel size to suit your needs.
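For illustration, shrinking the voxel size only means lowering the `size` fed to `GridSampling3D` in the datamodule config quoted above. A minimal sketch (the `0.02` value is hypothetical; pick it from your smallest object of interest following the rule of thumb above):

```yaml
# Hypothetical override: a 0.02 m voxel resolves objects down to ~0.04 m
# (Nyquist-style rule of thumb), at the cost of more memory downstream.
voxel: 0.02

pre_transform:
    - transform: GridSampling3D
      params:
        size: ${datamodule.voxel}
        hist_key: 'y'
        hist_size: ${eval:'${datamodule.num_classes} + 1'}
```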

Is there any possibility to run inference without sampling and keep the original data points?

See issue #9. For the reasons explained above, the bulk of the computation should probably not operate at full resolution. You can, however, convert the superpoint-level predictions to full-resolution output. As stated in the referenced issue, I have not had time to implement this. It should be quite straightforward, however: one just needs to keep track of per-voxel point indices, store them at preprocessing, and load them at inference when full resolution is required. Pull requests are welcome 😉
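The idea above can be sketched as follows. This is not the repository's implementation, just a minimal NumPy illustration of the two ingredients: recording which voxel each raw point falls into at preprocessing time, then broadcasting voxel-level predictions back to the raw points at inference time (the function names are hypothetical):

```python
import numpy as np

def voxelize_with_index(points, voxel_size):
    """Assign each raw point to a voxel.

    Returns `point_to_voxel`, where `point_to_voxel[i]` is the index of the
    voxel containing raw point i -- exactly the per-voxel point index one
    would store at preprocessing for later full-resolution recovery.
    """
    grid = np.floor(points / voxel_size).astype(np.int64)
    _, point_to_voxel = np.unique(grid, axis=0, return_inverse=True)
    return point_to_voxel

def upsample_predictions(voxel_pred, point_to_voxel):
    """Broadcast per-voxel predictions back to every raw point."""
    return voxel_pred[point_to_voxel]

# Toy example: 5 raw points, 0.5 m voxels
points = np.array([[0.1, 0.1, 0.1],
                   [0.2, 0.2, 0.2],   # same voxel as the first point
                   [0.9, 0.9, 0.9],
                   [1.4, 0.1, 0.1],
                   [1.3, 0.2, 0.2]])  # same voxel as the previous point
point_to_voxel = voxelize_with_index(points, voxel_size=0.5)
voxel_pred = np.array([3, 7, 5])      # one (fake) class label per voxel
full_res_pred = upsample_predictions(voxel_pred, point_to_voxel)
# full_res_pred -> [3, 3, 7, 5, 5]: every raw point inherits its voxel's label
```

In the actual pipeline, `point_to_voxel` would be saved alongside the preprocessed data and only loaded when a full-resolution output is requested.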

Your actual error

AssertionError: Expected Data.y to hold `(num_nodes, num_classes)` histograms, not single labels

This happens because GridSampling3D not only voxelizes your data, but also does some data preparation. In particular, it converts integer point labels into voxel label histograms. To avoid this error, I would recommend keeping GridSampling3D but reducing the voxel size to your needs. You can make it as small as you want if you want to operate on your full-resolution point cloud, though you may run into downstream memory issues...
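To make the assertion concrete, here is a minimal NumPy sketch of that label preparation (not the repository's exact code; the function name is hypothetical). Each voxel's per-point integer labels are counted into a `(num_voxels, num_classes + 1)` histogram, which is the 2D `Data.y` shape the assertion in `partition.py` expects:

```python
import numpy as np

def voxel_label_histograms(labels, point_to_voxel, num_voxels, num_classes):
    """Turn per-point integer labels into per-voxel label histograms.

    Mirrors in spirit what GridSampling3D does with hist_key='y' and
    hist_size=num_classes + 1: the extra column leaves room for an
    'ignored' label (label == num_classes).
    """
    hist = np.zeros((num_voxels, num_classes + 1), dtype=np.int64)
    np.add.at(hist, (point_to_voxel, labels), 1)  # scatter-count labels
    return hist

labels = np.array([0, 1, 1, 2, 2])          # per-point integer labels
point_to_voxel = np.array([0, 0, 1, 2, 2])  # voxel assignment per point
hist = voxel_label_histograms(labels, point_to_voxel, num_voxels=3, num_classes=3)
# hist has ndim == 2, i.e. (num_nodes, num_classes) histograms,
# which is what the failing assertion checks for
```

Deleting GridSampling3D skips this conversion, leaving `Data.y` as a 1D vector of single labels, hence the `AssertionError`.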

Yarroudh commented 11 months ago

Thanks for your answers, I get it now. For the full-resolution generalization, would it be possible to use the SaveNodeIndex class to save the indices? I can work on that and make a pull request if I get good results.

@akharroubi

drprojects commented 11 months ago

Yes, but not only. Off the top of my head, to implement this feature, one would need to do the following:

Some of the steps (saving / loading) will require a good understanding of the project structure.

drprojects commented 11 months ago

Closing this issue for now, which I consider a duplicate of the feature request in https://github.com/drprojects/superpoint_transformer/issues/9

Feel free to re-open it if you work on implementing the above-described pipeline.

PS: if you are interested in this project, don't forget to give it a ⭐, it matters to us!