allenai / satlaspretrain_models

Apache License 2.0

Multi Image SatlasNet #9

Open PumeTu opened 2 weeks ago

PumeTu commented 2 weeks ago

Hello,

Thank you for your work on the satlaspretrain models. I've been trying to use the multi-image feature of the pretrained models and am wondering how to recreate the SatlasNet architecture. I'm stuck on the input shape: the ResNet and Swin backbones expect 3 input channels. Does this mean we have to write the forward pass ourselves, instead of using the weight manager and calling get_pretrained_model(model_id, fpn=True, head=Classifier)?

favyen2 commented 2 weeks ago

The RGB models take 3 channels per image, while the multi-spectral models take 9 channels per image; see https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images for the list of bands we use.
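To make the channel arithmetic concrete, here is a small pure-Python sketch. It assumes (this is not stated in the reply) that a multi-image input stacks its images along the channel axis, so a 4-image multi-spectral input would have 4 * 9 = 36 channels and each image occupies a contiguous block of channels. The helper names are hypothetical and are not part of satlaspretrain_models.

```python
# Hypothetical helpers illustrating the per-image channel counts from the reply:
# 3 channels per image for the RGB models, 9 for the multi-spectral models.
CHANNELS_PER_IMAGE = {"rgb": 3, "ms": 9}


def stacked_channels(modality: str, num_images: int) -> int:
    """Total input channels, ASSUMING images are stacked along the channel axis."""
    if num_images < 1:
        raise ValueError("num_images must be >= 1")
    return CHANNELS_PER_IMAGE[modality] * num_images


def image_slices(total_channels: int, image_channels: int) -> list[tuple[int, int]]:
    """(start, stop) channel ranges, one per image, in a stacked input tensor."""
    if total_channels % image_channels != 0:
        raise ValueError("total_channels is not a multiple of image_channels")
    return [(i, i + image_channels) for i in range(0, total_channels, image_channels)]


print(stacked_channels("ms", 4))  # 36
print(image_slices(9, 3))         # [(0, 3), (3, 6), (6, 9)]
```

Under this assumption, the per-image count is what a multi-image wrapper needs to know in order to split the stacked tensor, which is why changing the input channels of a multi-image model also means changing that per-image value.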

If you want to use a different number of input channels but still use the pretrained weights, you can call get_pretrained_model and then override the input layer, similar to how we override the torchvision default at https://github.com/allenai/satlaspretrain_models/blob/225473ce01fca5898a6e2bf13a31475d8f5bb8ac/satlaspretrain_models/models/backbones.py#L27. For the multi-image models, you also need to override the image_channels field of the AggregationBackbone module.
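A minimal PyTorch sketch of the input-layer override described above. It builds a new Conv2d with the desired number of input channels, copies the pretrained filters for the channels both layers share, and fills any extra channels with the mean of the pretrained filters. That initialization is a common heuristic assumed here, not something taken from satlaspretrain_models, and replace_input_conv is a hypothetical name.

```python
import torch
import torch.nn as nn


def replace_input_conv(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Return a copy of `conv` that accepts `in_channels` input channels.

    Pretrained filters are reused for the shared channels; extra channels get the
    mean filter (an assumed heuristic, not necessarily the library's behavior).
    """
    assert conv.groups == 1, "sketch only handles ordinary (non-grouped) convolutions"
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        old_w = conv.weight  # shape: (out_channels, old_in_channels, kh, kw)
        shared = min(in_channels, old_w.shape[1])
        # Copy the pretrained filters for the channels both layers have.
        new_conv.weight[:, :shared] = old_w[:, :shared]
        if in_channels > shared:
            # Initialize each extra channel with the average pretrained filter.
            mean_filter = old_w.mean(dim=1, keepdim=True)
            new_conv.weight[:, shared:] = mean_filter.expand(-1, in_channels - shared, -1, -1)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Usage on a stand-in layer shaped like a ResNet stem (3 -> 64 channels).
stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_stem = replace_input_conv(stem, 4)  # e.g. RGB plus one extra band (hypothetical)
print(new_stem(torch.randn(1, 4, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])
```

On a real model you would assign the returned layer back to the backbone's first convolution (the attribute path differs between the ResNet and Swin backbones, so check backbones.py), and for a multi-image model also set image_channels on the AggregationBackbone to the new per-image channel count.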

If you don't want to use pre-trained weights, then you can pass the num_channels argument to the satlaspretrain_models.model.Model class.