drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"

fine-tuning on pretrained weights #70

Closed · gvoysey closed this 6 months ago

gvoysey commented 6 months ago

Hi Damien,

Thanks for this paper and project! Do you have docs or guidance for fine-tuning on pretrained weights with additional data that has different classes than the base weights? I don't see anything in the docs at first glance.

Having tried this fairly naively on the DALES weights, I get errors like `Sizes of tensors must match except in dimension 0. Expected size 9 but got size 11 for tensor 2 in the list`.

drprojects commented 6 months ago

Hi @gvoysey, thanks for your interest in the project!

We do not provide a specific off-the-shelf script for fine-tuning, but the entire project is built on PyTorch Lightning, which comes with a lot of utilities for this; I invite you to have a look.

However, if your dataset has different classes than DALES, you will need to create your own dataset (see the docs). From there, the easy solution would be to train from scratch if you have enough data. Otherwise, you can do more advanced things like replacing the last classification layer and retraining it from scratch while freezing the rest of the model; see the PyTorch Lightning docs for that.
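For instance, a rough sketch of that head-replacement idea could look like the following. This is a toy model, not the actual Superpoint Transformer modules; the `ToySegmentationModel`, `backbone`, and `head` names are placeholders you would adapt to your own setup and checkpoint layout.

```python
import torch
import torch.nn as nn


# Toy stand-in for a pretrained backbone + classification head; this is NOT
# the Superpoint Transformer module layout, just an illustration of the idea.
class ToySegmentationModel(nn.Module):
    def __init__(self, feat_dim: int = 64, num_classes: int = 9):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(3, feat_dim), nn.ReLU())
        self.head = nn.Linear(feat_dim, num_classes)  # per-point class logits

    def forward(self, x):
        return self.head(self.backbone(x))


def prepare_for_finetuning(model: nn.Module, num_new_classes: int) -> nn.Module:
    """Freeze the pretrained backbone and re-initialize the classifier for a
    new label set (e.g. 11 classes instead of DALES' 9)."""
    for param in model.backbone.parameters():
        param.requires_grad = False
    in_features = model.head.in_features
    model.head = nn.Linear(in_features, num_new_classes)  # trainable new head
    return model


if __name__ == "__main__":
    model = ToySegmentationModel(num_classes=9)
    # In practice you would load your pretrained checkpoint here before
    # swapping the head; the exact state dict keys depend on your Lightning setup.
    model = prepare_for_finetuning(model, num_new_classes=11)
    print(model(torch.randn(8, 3)).shape)  # torch.Size([8, 11])
```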

gvoysey commented 6 months ago

Thanks for the pointers!

I'd like to avoid retraining from scratch every time, though I do anticipate having at least a DALES-sized amount of data at hand. I'm coming primarily from an image object detection background, so "do output-layer surgery and fine-tune on smaller amounts of data than the base weights were trained on" is the SOP I was expecting -- is it very different in the LiDAR world?

drprojects commented 6 months ago

Glad to hear you are familiar with fine-tuning strategies then! No, the fine-tuning recipes are the same for 3D and 2D. As usual in deep learning, there is no perfect recipe, so you might need to try various strategies and see what works best for your case.

One thing to keep in mind though: similar to other transformer-based works before, we have found that using smaller learning rates on the self-attention blocks than on the rest of the architecture helps stabilize training. You might want to take this into account when building your optimizers and schedulers. See where `transformer_lr_scale` is involved in the code to get an idea :wink:
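To give a quick idea of what that looks like outside the repo's config system, here is a minimal sketch using optimizer parameter groups; the `build_optimizer` helper and the `"attention"` name filter are placeholders, not the project's actual code.

```python
import torch
import torch.nn as nn


def build_optimizer(model: nn.Module, base_lr: float = 1e-2,
                    transformer_lr_scale: float = 0.1) -> torch.optim.Optimizer:
    """Put self-attention parameters in their own group with a scaled-down
    learning rate. The name filter ("attention") is a placeholder for however
    the attention blocks are actually named in your model."""
    attn_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # skip frozen parameters
        (attn_params if "attention" in name else other_params).append(param)
    return torch.optim.AdamW([
        {"params": other_params, "lr": base_lr},
        {"params": attn_params, "lr": base_lr * transformer_lr_scale},
    ])
```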