Helmholtz-AI-Energy / TBBRDet

Thermal Bridges on Building Rooftops Detection (TBBRDet)
BSD 3-Clause "New" or "Revised" License

How to modify pretrained weighted file : mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth #4

Open farzinnikkhah opened 6 months ago

farzinnikkhah commented 6 months ago

Hi,

I am currently working on a project using the MMDetection framework and am trying to train a Mask R-CNN model on the TBBRv2 dataset, which has 5-channel images. I encountered an issue when trying to adapt the standard 3-channel pretrained model for this purpose.

During training, I ran into the following error: RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 2, 5, 2688, 3392]
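The shape in the error has five dimensions, while PyTorch's conv2d only accepts (C, H, W) or (N, C, H, W). The trailing part, [2, 5, 2688, 3392], looks like a batch of two 5-channel images with one extra leading dimension in front of it. The small check below illustrates this, assuming the standard NCHW layout. The helper name check_conv2d_input is hypothetical and is not part of PyTorch or MMDetection.

```python
# Sketch: compare a tensor shape against what PyTorch's conv2d accepts,
# assuming the standard (C, H, W) / (N, C, H, W) layout.

def check_conv2d_input(shape, in_channels):
    """Return a short diagnosis string for a would-be conv2d input shape."""
    ndim = len(shape)
    if ndim not in (3, 4):
        return f"invalid: {ndim}D input, conv2d expects 3D (C, H, W) or 4D (N, C, H, W)"
    # Channel axis is first for unbatched input, second for batched input.
    channels = shape[0] if ndim == 3 else shape[1]
    if channels != in_channels:
        return f"invalid: {channels} channels, first conv layer expects {in_channels}"
    return "ok"


# The shape from the error message: one dimension too many.
print(check_conv2d_input((1, 2, 5, 2688, 3392), in_channels=5))
# A batch of two 5-channel images in (N, C, H, W) layout.
print(check_conv2d_input((2, 5, 2688, 3392), in_channels=5))
```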

I noticed that in your project you used a modified pretrained file (e.g., "mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb_truncated.pth"), which might have been adapted for a similar scenario. Could you provide guidance on how to modify the pretrained model file to make it compatible with 5-channel input? Any advice or pointers to resources or documentation would be greatly appreciated.
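For reference, one generic way to adapt a 3-channel checkpoint to 5-channel input is to expand the first convolution's weights: keep the RGB filters and initialise each extra input channel with their mean. This is a common heuristic, not necessarily how the _truncated.pth file was produced. The parameter name "backbone.conv1.weight" is an assumption about the ResNet stem in MMDetection checkpoints. With a real .pth file you would typically load it with torch.load, apply the same logic to the tensor under checkpoint["state_dict"], and save it with torch.save; that part is left out so the sketch runs with NumPy alone.

```python
import numpy as np


def expand_conv_in_channels(weight, new_in_channels):
    """Expand a conv weight of shape (out, in, kh, kw) to (out, new_in, kh, kw).

    The original filters are copied unchanged; each extra input channel is
    initialised with the mean over the original input channels (a heuristic).
    """
    out_ch, in_ch, kh, kw = weight.shape
    if new_in_channels < in_ch:
        raise ValueError("new_in_channels must be >= the existing number of channels")
    extra = new_in_channels - in_ch
    mean_filter = weight.mean(axis=1, keepdims=True)  # shape (out, 1, kh, kw)
    return np.concatenate([weight, np.repeat(mean_filter, extra, axis=1)], axis=1)


def adapt_state_dict(state_dict, key="backbone.conv1.weight", new_in_channels=5):
    """Return a shallow copy of state_dict with the stem conv expanded.

    `key` is an assumed parameter name for the first convolution.
    """
    adapted = dict(state_dict)
    adapted[key] = expand_conv_in_channels(state_dict[key], new_in_channels)
    return adapted


# Toy example: a ResNet-style 7x7 stem with 64 filters over 3 input channels.
rng = np.random.default_rng(0)
sd = {"backbone.conv1.weight": rng.standard_normal((64, 3, 7, 7)).astype(np.float32)}
new_sd = adapt_state_dict(sd)
print(new_sd["backbone.conv1.weight"].shape)  # (64, 5, 7, 7)
```

Averaging the RGB filters keeps the extra channels' initial response on the same scale as the pretrained ones; zero-initialising them is another common choice.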

Thank you for your assistance!

emvollmer commented 5 months ago

Hi there,

If I understand correctly, you're having trouble reproducing the training of the Mask R-CNN model with pretrained weights on the 5-channel TBBRv2 dataset. This actually doesn't rely on modified weights files. I know this line seems to suggest that, but for the swin-t version, for example, you'll notice that the standard weights work fine. You simply need to run training with the correct config file .pretrained.py and set the load_from parameter to the path of the standard, downloaded weights file. The custom NumPy file loader and the in_channels=5 setting handle the 5-channel inputs.

In my fork of this repo, I've documented how to train with pretrained weights; the example train command there is for 5-channel inputs. A small forewarning: I define most parameters (such as load_from) directly via the command line to keep things simple, but I amended the scripts to support that, so this might not work with the scripts here. I'm not sure exactly which commands you tried that returned that RuntimeError, but I hope this points you in the right direction. If not, please share some more specific details (e.g., command(s), system setup, etc.).
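For reference, stock MMDetection's tools/train.py accepts --cfg-options key=value to override config entries from the command line, so load_from can be set without editing the config. The sketch below only assembles such a command as an argument list. The config and checkpoint paths are hypothetical, and, as noted above, the scripts in this repo may behave differently.

```python
import shlex


def build_train_command(config_path, checkpoint_path, work_dir=None):
    """Build an argv list for MMDetection's tools/train.py.

    `--cfg-options key=value` overrides config entries from the command line;
    here it sets load_from to the standard pretrained weights.
    """
    cmd = ["python", "tools/train.py", config_path,
           "--cfg-options", f"load_from={checkpoint_path}"]
    if work_dir is not None:
        cmd += ["--work-dir", work_dir]
    return cmd


cmd = build_train_command(
    "configs/hypothetical_5ch_pretrained.py",  # hypothetical config path
    "checkpoints/mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth",
)
# Print a shell-ready version of the command; pass `cmd` to subprocess.run to execute it.
print(shlex.join(cmd))
```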

Cheers!