drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"
MIT License

Large variance in performance of re-trained models. #113

Closed biophase closed 1 month ago

biophase commented 1 month ago

Hello,

Thank you for your amazing work! SPT's lightweight architecture looks really promising for large remote sensing datasets due to its low latency and memory consumption. However, we're seeing high variance in performance when retraining a model with an identical config. Here's some background on what we're trying to achieve:

We're currently testing SPT on our own dataset of infrastructure assets such as highway bridges and sign bridges. To this end, we adapted the DALES datamodule config, with the following two modifications:

Out of the box, SPT performed slightly worse than other architectures, even though it should outperform them on benchmark datasets. With the current configuration, we achieve the following validation mIoU:

| Architecture | mIoU |
|---|---|
| PointNeXt-XL | 71.67 |
| KPConv | 70.77 |
| PTv3 | 76.11 |
| SPT | 70.38 |

These results are encouraging, and they suggest there may be room for further hyperparameter tuning. The problem is that even retraining the model with the same configuration leads to vastly different results, which prevents us from attributing an improvement to a specific hyperparameter change. For example, here are the sorted mIoU results of 13 runs with the config mentioned above (standard deviation of ~1.27 and a range of ~4.4 mIoU points):

```
In [3]: mious
Out[3]:
array([70.38345337, 69.77773285, 69.54872131, 68.86537933, 68.28470612,
       68.27809143, 68.10830688, 67.68677521, 67.30361176, 67.17458344,
       66.95844269, 66.50408173, 65.95107269])

In [4]: mious.mean()
Out[4]: 68.06345837

In [5]: mious.std()
Out[5]: 1.2676893110159888

In [6]: print(mious.min(), mious.max())
65.95107269 70.38345337
```
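The spread above can also be summarized with Python's standard `statistics` module. This is a minimal sketch using the 13 values reported above. Note that NumPy's `ndarray.std()` defaults to the population standard deviation (`ddof=0`), so the sample standard deviation (`ddof=1`) is slightly larger. `summarize_runs` is a hypothetical helper name.

```python
import statistics

# mIoU of the 13 retraining runs reported above
mious = [70.38345337, 69.77773285, 69.54872131, 68.86537933, 68.28470612,
         68.27809143, 68.10830688, 67.68677521, 67.30361176, 67.17458344,
         66.95844269, 66.50408173, 65.95107269]


def summarize_runs(values):
    """Summarize the run-to-run spread of a metric (hypothetical helper)."""
    n = len(values)
    sample_std = statistics.stdev(values)  # divides by n - 1
    return {
        "n": n,
        "mean": statistics.mean(values),
        "pop_std": statistics.pstdev(values),  # divides by n, like numpy's default
        "sample_std": sample_std,
        "range": max(values) - min(values),
        # standard error of the mean: how precisely n runs pin down the mean
        "sem": sample_std / n ** 0.5,
    }


for key, value in summarize_runs(mious).items():
    print(f"{key:>10}: {value:.4f}" if isinstance(value, float) else f"{key:>10}: {value}")
```

With a sample standard deviation around 1.3, the difference between two single runs has a spread of roughly 1.3 × √2 ≈ 1.9 points. A single run per configuration therefore cannot reliably separate hyperparameter changes smaller than about two mIoU points.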

The values are reported at the end of training or obtained with the src/eval.py script. Re-evaluating the same model twice does not seem to make a large difference. Can you suggest how we can make the mIoU stable across training runs? Or is this expected behavior? If so, what could be the reason?

Thank you in advance

loicland commented 1 month ago

Hi,

Variance in performance is a deep-learning-wide issue and not specific to SPT. However, it can be exacerbated by superpoint-based methods if the dataset is too small: by reducing the size of the problem, SPT also reduces the amount of supervision.

How large is your dataset? Can you post a picture of the partition? Do you have very rare classes or very small instances?

To combat this, I would advise using a finer partition with smaller superpoints, for example by lowering pcp_regularization and cutoff.

biophase commented 1 month ago

Hello @loicland,

Thanks for the quick suggestions. I believe they helped me pinpoint the issue. The variance is observed only on the validation data (training performance is roughly constant across runs), so there may be some overfitting on ambiguous classes.

Our dataset is indeed challenging, with large differences in object sizes and total number of points. The dataset is moderately sized, with 48 training point clouds and 21 validation point clouds. Here's a screenshot of the class distribution:

[Image: table with columns #, Class, Number of points, Number of objects]

And here is what I think is the main limitation of our dataset: some point clouds were labeled using an additional 'intensity' feature, which we then removed (since not all point clouds had it). As a result, one of the classes (the side lane of the road) is ambiguous based on geometric features alone. This causes superpoints to bleed into the actual road, which also shows up as large variance in the class 'IfcRoad_roadside'.

[Images: ground truth | prediction | errors]

[Images: partition level 1 | partition level 2 | partition level 3]

Since posting the issue, I set several hyperparameters to the values that individually gave the highest mIoU, which actually reduced the variance somewhat. Interestingly, changing node_feat_jitter from 0 to 0.01 and h_edge_feat_jitter from 0 to 0.03 had the largest positive effect.

I still need to experiment with the values more, but for now this is what I get with 5 runs per value:

pcp_regularization: [0.1, 0.2, 0.3]     -->  70.078 mIoU (+/-0.683); max mIoU = 71.244
pcp_regularization: [0.1, 0.5, 1.5]     -->  70.307 mIoU (+/-0.582); max mIoU = 71.127
pcp_regularization: [0.05, 0.1, 0.2]    -->  69.900 mIoU (+/-0.725); max mIoU = 71.273
pcp_regularization: [0.2, 0.8, 2.0]     -->  70.156 mIoU (+/-0.626); max mIoU = 70.972
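To judge whether the differences between these settings exceed run-to-run noise, one option is Welch's t-test computed from summary statistics. This is a minimal sketch, assuming the `+/-` values are standard deviations over the 5 runs. `welch_t` is a hypothetical helper; converting t into a p-value would need a t-distribution CDF (e.g. from SciPy), which is omitted here.

```python
import math


def welch_t(mean_a, std_a, n_a, mean_b, std_b, n_b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom."""
    var_a = std_a ** 2 / n_a  # squared standard error of mean A
    var_b = std_b ** 2 / n_b  # squared standard error of mean B
    t = (mean_a - mean_b) / math.sqrt(var_a + var_b)
    dof = (var_a + var_b) ** 2 / (var_a ** 2 / (n_a - 1) + var_b ** 2 / (n_b - 1))
    return t, dof


# Highest vs. lowest mean from the list above, 5 runs each
t, dof = welch_t(70.307, 0.582, 5, 69.900, 0.725, 5)
print(f"t = {t:.2f}, dof = {dof:.1f}")  # t = 0.98, dof = 7.6
```

A |t| near 1 with about 7.6 degrees of freedom is well below the two-sided 5% critical value (roughly 2.3 to 2.4). Under the stated assumption, these four settings are not clearly distinguishable with 5 runs each.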

I consider this issue solved since the variance is now notably lower.

But I would still like to ask you about pcp_regularization and the coarseness of the graph: does it affect the receptive field of the network in any way, e.g. fine superpoints = small receptive field? Or is the receptive field global, e.g. each superpoint at $\mathcal{P}_1$ attends to all other superpoints at the same level regardless of how many there are? Eq. (3) of your paper suggests that self-attention is applied to immediate neighbors in the graph only (but perhaps my understanding is wrong).

drprojects commented 1 month ago

Hi @biophase, thank you for the detailed feedback!

Class imbalance

Indeed, heavy class imbalance will affect generalization performance and introduce variance in the results, since a small change in the predictions on a rare class can have a dramatic impact on the metrics. This is not specific to SPT, but the superpoint partition might exacerbate this behavior for classes of small, under-represented objects (e.g. bridge pier, bridge railing, footing, traffic sign, ...).
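A toy example of why rare classes make mIoU volatile: mIoU averages per-class IoU with equal weight, so a few dozen misclassified points of a rare class can move mIoU far more than overall accuracy (OA). The confusion matrices below are made-up numbers for illustration, not results from this dataset.

```python
import numpy as np


def per_class_iou(cm):
    """Per-class IoU from a confusion matrix (rows: ground truth, cols: prediction)."""
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp  # predicted as class c but labeled otherwise
    fn = cm.sum(axis=1) - tp  # labeled as class c but predicted otherwise
    return tp / (tp + fp + fn)


def overall_accuracy(cm):
    return np.trace(cm) / cm.sum()


# Toy 3-class setup: two large classes and one rare class with 50 points
cm_good = np.array([[9500, 500, 0],
                    [500, 4500, 0],
                    [0, 0, 50]])
# Same predictions, except 40 rare-class points are now predicted as class 0
cm_bad = np.array([[9500, 500, 0],
                   [500, 4500, 0],
                   [40, 0, 10]])

for name, cm in [("good", cm_good), ("bad", cm_bad)]:
    miou = per_class_iou(cm).mean()
    print(f"{name}: mIoU = {100 * miou:.1f}, OA = {100 * overall_accuracy(cm):.1f}")
# good: mIoU = 90.8, OA = 93.4
# bad: mIoU = 64.0, OA = 93.1
```

Here 40 points (about 0.3% of all points) lower OA by about 0.3 points but mIoU by nearly 27 points.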

Radiometric features

As you identified, radiometric information can be crucial for separating classes that would otherwise be indistinguishable based on geometry alone. Again, this is not an SPT-specific problem: any 3D processing method would struggle to identify a poster on a wall (or a lane on the road) based on geometry alone.

Depending on the ratio of your points that have an intensity (and whether you expect to have access to it in your typical inference scenario), it may be worth using it anyway and setting it to 0 when it is not available. See the DALES dataset implementation for how to use intensity. You can also play with node_feat_drop and node_feat_jitter in your datamodule config to introduce some augmentations on the point features and see if that helps.
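The "set it to 0 when not available" idea can be sketched as follows, assuming per-cloud NumPy arrays with intensity already normalized to [0, 1]. The function name and the optional availability flag column are hypothetical and not part of SPT; the DALES dataset implementation remains the reference for how intensity is actually read and used.

```python
import numpy as np


def intensity_feature(num_points, intensity=None, add_flag=False):
    """Return an (N, 1) intensity feature, filled with 0 when the cloud has none.

    Hypothetical helper: assumes `intensity`, when present, is already
    normalized to [0, 1] with one value per point. With `add_flag`, a second
    column marks whether intensity was actually measured (1) or filled (0).
    """
    if intensity is None:
        values = np.zeros((num_points, 1), dtype=np.float32)
        available = 0.0
    else:
        values = np.asarray(intensity, dtype=np.float32).reshape(-1, 1)
        if values.shape[0] != num_points:
            raise ValueError("expected one intensity value per point")
        available = 1.0
    if add_flag:
        flag = np.full((num_points, 1), available, dtype=np.float32)
        values = np.hstack([values, flag])
    return values


xyz = np.zeros((5, 3), dtype=np.float32)  # placeholder coordinates
feat = intensity_feature(len(xyz), add_flag=True)  # cloud recorded without intensity
print(np.hstack([xyz, feat]).shape)  # (5, 5)
```

The flag column is one possible variant: it lets the network tell a missing intensity apart from a genuinely low one, at the cost of one extra feature channel.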

Augmentations

Indeed, node_feat_jitter and h_edge_feat_jitter are augmentations that may be worth exploring for your dataset. You can tune many different augmentations in the datamodule configs:

```yaml
# Augmentations parameterization
pos_jitter: 0.05
tilt_n_rotate_phi: 0.1
tilt_n_rotate_theta: 180
anisotropic_scaling: 0.2
node_feat_jitter: 0
h_edge_feat_jitter: 0
v_edge_feat_jitter: 0
node_feat_drop: 0
h_edge_feat_drop: 0.3
v_edge_feat_drop: 0
node_row_drop: 0
h_edge_row_drop: 0
v_edge_row_drop: 0
drop_to_mean: False
```
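For intuition on parameters like `node_feat_jitter`, `node_feat_drop`, and `drop_to_mean`, here is a generic sketch of feature jitter and feature dropout on a node feature matrix. It assumes jitter means additive Gaussian noise with the given standard deviation, and drop means masking whole feature columns with the given probability, with `drop_to_mean` filling masked columns with their mean instead of 0. These semantics and the function names are assumptions for illustration; SPT's own transforms may be defined differently.

```python
import numpy as np


def jitter_features(x, sigma, rng):
    """Add zero-mean Gaussian noise with std `sigma` to every feature (assumed semantics)."""
    if sigma <= 0:
        return x.copy()
    return x + rng.normal(0.0, sigma, size=x.shape)


def drop_feature_columns(x, p, rng, drop_to_mean=False):
    """Mask whole feature columns with probability `p` (assumed semantics).

    Masked columns are set to 0, or to the column mean if `drop_to_mean`.
    """
    x = x.copy()
    if p <= 0:
        return x
    dropped = rng.random(x.shape[1]) < p  # one Bernoulli draw per feature column
    fill = x[:, dropped].mean(axis=0) if drop_to_mean else 0.0
    x[:, dropped] = fill
    return x


rng = np.random.default_rng(0)
feats = rng.random((1000, 8))  # toy node features: 1000 nodes, 8 channels
aug = drop_feature_columns(jitter_features(feats, 0.01, rng), 0.3, rng)
print(aug.shape)  # (1000, 8)
```

Since augmentations are resampled at every training step, small values such as the 0.01 and 0.03 jitters reported above act as mild regularizers rather than changing the data distribution much.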

There are also parameters for batch construction:

```yaml
# Batch construction parameterization
sample_segment_ratio: 0.2
sample_segment_by_size: True
sample_segment_by_class: False
sample_point_min: 32
sample_point_max: 128
sample_graph_r: 50  # set to r<=0 to skip SampleRadiusSubgraphs
sample_graph_k: 4
sample_graph_disjoint: True
sample_edge_n_min: -1  # [5, 5, 15]
sample_edge_n_max: -1  # [10, 15, 25]
```

sample_segment_ratio, sample_graph_r, and sample_graph_k may have an impact on performance (but also on GPU memory). sample_segment_by_class might also be worth a try for your class imbalance problem.
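For intuition on how class-aware sampling can counter imbalance, here is a generic sketch that draws a fraction of segments with probability inversely proportional to their class frequency. It only illustrates the idea: whether `sample_segment_by_class` uses this exact weighting is an assumption. The helper names and the `keep_ratio` parameter (fraction of segments kept) are hypothetical.

```python
import numpy as np


def class_balanced_weights(labels):
    """Sampling probability per segment, inversely proportional to its class count.

    Every class ends up with the same total probability mass.
    """
    labels = np.asarray(labels)
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    return weights / weights.sum()


def sample_segments_by_class(labels, keep_ratio, rng):
    """Draw round(keep_ratio * N) distinct segment indices, favoring rare classes."""
    p = class_balanced_weights(labels)
    k = max(1, int(round(keep_ratio * len(p))))
    return rng.choice(len(p), size=k, replace=False, p=p)


rng = np.random.default_rng(0)
labels = np.array([0] * 95 + [1] * 5)  # toy labels: one rare class
picked = sample_segments_by_class(labels, 0.2, rng)
print(np.bincount(labels[picked], minlength=2))  # segments drawn per class
```

With uniform sampling, the rare class would contribute about 1 of the 20 drawn segments on average; with these weights, it receives half of the initial probability mass.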

Receptive field

Each superpoint attends to neighboring superpoints in the same partition level. The neighborhood is constructed at preprocessing time with RadiusHorizontalGraph. You can play with graph_gap in your datamodule config: it sets the threshold on the minimum pointwise distance between two superpoints for them to be neighbors (a radius NN search, but for point sets, described in our paper's appendix). Be aware that increasing the number of neighbors of a superpoint will increase the number of edges in the attention graph, which may in turn impact compute and memory. graph_k_max can help you keep the number of edges per superpoint in check. Also, max_num_nodes and max_num_edges can be adjusted if need be.
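The neighborhood rule described above can be illustrated with a brute-force sketch: two superpoints are neighbors when the minimum distance between their points is at most a gap, and each superpoint keeps at most `k_max` neighbors. Keeping the closest neighbors first is an assumed tie-breaking rule, and the function names are hypothetical. This naive O(N²) version is for intuition only; it is not SPT's RadiusHorizontalGraph, which is designed to scale to large clouds.

```python
import numpy as np


def min_set_distance(a, b):
    """Smallest Euclidean distance between any point of `a` and any point of `b`."""
    diff = a[:, None, :] - b[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1)).min()


def superpoint_edges(superpoints, gap, k_max):
    """Return directed edges (i, j) between superpoints at most `gap` apart.

    `superpoints` is a list of (n_i, 3) arrays. Each superpoint keeps at most
    `k_max` neighbors, the closest ones first (assumed tie-breaking rule).
    """
    n = len(superpoints)
    dist = np.full((n, n), np.inf)  # diagonal stays inf: no self-loops
    for i in range(n):
        for j in range(i + 1, n):
            d = min_set_distance(superpoints[i], superpoints[j])
            dist[i, j] = dist[j, i] = d
    edges = []
    for i in range(n):
        candidates = np.flatnonzero(dist[i] <= gap)
        order = np.argsort(dist[i, candidates], kind="stable")
        edges.extend((i, int(j)) for j in candidates[order][:k_max])
    return edges


# Three toy superpoints along the x axis
a = np.array([[0.0, 0, 0], [1.0, 0, 0]])
b = np.array([[1.5, 0, 0], [2.5, 0, 0]])
c = np.array([[10.0, 0, 0], [11.0, 0, 0]])
print(superpoint_edges([a, b, c], gap=1.0, k_max=4))  # [(0, 1), (1, 0)]
```

Raising `gap` adds edges (and attention cost), while `k_max` caps the degree of each superpoint, which mirrors the trade-off between graph_gap and graph_k_max described above.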