drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"
MIT License

In custom dataset where point number in classes are quite different #62

Closed yudopan closed 7 months ago

yudopan commented 7 months ago

Greetings, Thank you very much for your fantastic work!

I have a question about applying it to a custom dataset for semantic segmentation. In my dataset, the number of points per class varies widely: for example, class A has 100 thousand points while class B has only 100. During training, I would like all classes to be considered equally when the weights are updated. Could you please suggest what I could do? Thank you very much! Best regards, Yudopan

drprojects commented 7 months ago

Hi @yudopan, thanks for your interest in our project!

There are two ways you can tackle class imbalance: setting weights in your loss to give more importance to rare classes, or sampling your data during training.

For the first option, setting:

weighted_loss: True

in the model config will automatically call the get_class_weight() method of your dataset when training starts. This computes the frequency of each class in your dataset and sets the loss weights accordingly. I invite you to have a look at the code for more details; it is fairly well commented. Note that weighted_loss=True is the default for all models in the provided code.
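To make the idea of frequency-based loss weights concrete, here is a minimal sketch in plain Python. It is not the project's get_class_weight() implementation; the function name inverse_frequency_weights and its smooth and power parameters are hypothetical, chosen only for illustration. It assumes you have a flat list of per-point integer labels.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes, smooth=1.0, power=1.0):
    """Return one weight per class, larger for rarer classes.

    Hypothetical helper for illustration, not part of the repository.
    `smooth` avoids a division by zero for classes absent from `labels`.
    `power` < 1 softens the weights (0.5 gives inverse square-root frequency).
    """
    counts = Counter(labels)
    total = len(labels)
    # Inverse of the (smoothed) class frequency, optionally softened
    raw = [
        (total / (counts.get(c, 0) + smooth)) ** power
        for c in range(num_classes)
    ]
    # Normalize so that the weights average to 1 across classes
    mean = sum(raw) / num_classes
    return [w / mean for w in raw]


# Toy data mirroring the question: 100 thousand points vs. 100 points
labels = [0] * 100_000 + [1] * 100
weights = inverse_frequency_weights(labels, num_classes=2)
soft_weights = inverse_frequency_weights(labels, num_classes=2, power=0.5)
print(weights, soft_weights)
```

A list like this can be converted to a tensor and passed as the weight argument of torch.nn.CrossEntropyLoss. With such a strong imbalance, pure inverse-frequency weights differ by roughly three orders of magnitude, which is why a softened variant may be worth trying.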

For the second option, you can play with the by_class=True option of the SampleSegments and SampleRadiusSubgraphs transforms. By default, this behavior is not activated. Here again, I invite you to have a look at the code for these transforms to understand what they do and what the class-weighted sampling will do.
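The general principle of class-weighted sampling can be sketched as follows. This is not the code of SampleSegments or SampleRadiusSubgraphs; the function sample_by_class and its arguments are hypothetical. It assumes each candidate segment carries a single integer class label, and it draws with replacement so that every class present is picked with roughly equal probability.

```python
import random
from collections import Counter


def sample_by_class(segment_labels, k, seed=None):
    """Draw k segment indices, with replacement, balancing classes.

    Hypothetical helper for illustration, not part of the repository.
    Each segment is weighted by 1 / (number of segments in its class),
    so the total sampling weight of every class present is the same.
    """
    rng = random.Random(seed)
    counts = Counter(segment_labels)
    weights = [1.0 / counts[label] for label in segment_labels]
    return rng.choices(range(len(segment_labels)), weights=weights, k=k)


# Toy data: 1000 segments of class 0 and only 10 segments of class 1
segment_labels = [0] * 1000 + [1] * 10
picked = sample_by_class(segment_labels, k=2000, seed=0)
rare_fraction = sum(segment_labels[i] == 1 for i in picked) / len(picked)
print(rare_fraction)
```

Note that balancing this way means the few rare segments are seen many times per epoch, which ties in with the overfitting concern below.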

That being said, if one of your classes only has 100 points, I fear it may be very hard not to overfit to the training examples. For instance, in KITTI-360, some under-represented classes such as traffic light or motorcycle are very hard to capture despite having multiple instances of hundreds of points in the dataset. I do not know your dataset, but maybe you can find a larger dataset to pretrain on that contains classes similar to your classes of interest.

Good luck! Damien