microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Support model in_chans not equal to pre-trained weights in_chans #2289

Open: adamjstewart opened this issue 2 weeks ago

adamjstewart commented 2 weeks ago

Summary

If a user specifies both in_chans and weights, and weights.meta['in_chans'] differs from in_chans, the user-specified argument should take precedence, and the pre-trained first-layer weights should be repeated along the input-channel dimension to match it, similar to how timm handles pre-trained weights.
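The precedence rule above can be sketched as a small pure-Python decision function. The name `resolve_in_chans` is hypothetical, and `weights_in_chans` stands in for `weights.meta['in_chans']` (or `None` when no weights are passed). This sketches the proposed behavior, not necessarily what torchgeo does today.

```python
from typing import Optional, Tuple


def resolve_in_chans(
    in_chans: Optional[int], weights_in_chans: Optional[int]
) -> Tuple[int, bool]:
    """Return (in_chans to build the model with, whether the first layer needs adapting).

    Hypothetical helper: a user-specified in_chans wins; otherwise fall back to the
    channel count recorded with the weights (weights.meta['in_chans']).
    """
    if in_chans is None and weights_in_chans is None:
        raise ValueError('either in_chans or weights must be given')
    if in_chans is None:
        # No user override: build the model to match the weights exactly.
        return weights_in_chans, False
    if weights_in_chans is None:
        # No pre-trained weights: nothing to adapt.
        return in_chans, False
    # User override: adapt only if the channel counts actually differ.
    return in_chans, in_chans != weights_in_chans
```

For the change-detection case below, `resolve_in_chans(4, 2)` returns `(4, True)`: the model is built with 4 input channels and the 2-channel first layer has to be expanded.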

Rationale

When working on change detection, it is common to take two images and stack them along the channel dimension. However, the stacked input then has more channels than our pre-trained weights expect, so those weights cannot be used. Ideally, I would like to support something like:

```python
from torchgeo.models import ResNet50_Weights, resnet50

model = resnet50(in_chans=4, weights=ResNet50_Weights.SENTINEL1_ALL_MOCO)
```

Here, the weights have 2 channels (VV and VH), while the dataset and model will have 4 channels (VV, VH, VV, VH).
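To see what a repeated first layer computes for this stacked input, here is a quick NumPy check. It assumes a 1x1 convolution (a per-pixel matrix-vector product) as a stand-in for the real first layer, and a repeat-and-rescale scheme like timm's, with the scale factor generalized from 3/in_chans to (weight channels)/(model channels), here 2/4. All array names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a pre-trained 1x1 conv: 8 output channels, 2 input channels.
w2 = rng.standard_normal((8, 2))
# Repeat the input-channel weights to cover 4 channels and scale by 2/4,
# so the summed weight per output channel is unchanged.
w4 = np.tile(w2, (1, 2)) * (2 / 4)

# One pixel from each date, stacked along the channel axis.
t1 = rng.standard_normal(2)
t2 = rng.standard_normal(2)
stacked = np.concatenate([t1, t2])  # [ch0_t1, ch1_t1, ch0_t2, ch1_t2]

out4 = w4 @ stacked
# The 4-channel response equals the 2-channel response to the mean of the two dates.
assert np.allclose(out4, w2 @ ((t1 + t2) / 2))
```

Because convolution is linear, the same holds at every pixel for larger kernels: at initialization the adapted layer sees the average of the two dates, so any sensitivity to change has to come from fine-tuning.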

Implementation

https://timm.fast.ai/models#Case-2:-When-the-number-of-input-channels-is-not-1 describes the approach that timm uses. The timm function that implements it can be imported as:

```python
from timm.models.helpers import load_pretrained
```

We should make use of this in all of our model definitions instead of model.load_state_dict.
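Whichever helper ends up being used, the step that replaces a bare model.load_state_dict call is adapting mismatched first-layer weights in the state dict before loading. A minimal sketch under stated assumptions: `adapt_state_dict` is a hypothetical name, NumPy arrays stand in for tensors, and `target_shapes` would come from the freshly built model (in PyTorch, something like `{k: tuple(v.shape) for k, v in model.state_dict().items()}`). Mismatched 4-D weights are tiled along the input-channel axis and rescaled; every other entry passes through unchanged.

```python
import numpy as np


def adapt_state_dict(state_dict, target_shapes):
    """Return a copy of state_dict whose weights match target_shapes.

    Hypothetical helper: only 4-D conv weights that differ in the input-channel
    axis (axis 1) are changed; all other entries are passed through unchanged.
    """
    adapted = {}
    for name, weight in state_dict.items():
        target = tuple(target_shapes[name])
        if weight.shape != target:
            same_elsewhere = (
                weight.ndim == 4
                and len(target) == 4
                and weight.shape[0] == target[0]
                and weight.shape[2:] == target[2:]
            )
            if not same_elsewhere:
                raise ValueError(f'{name}: cannot adapt {weight.shape} to {target}')
            in_w, in_c = weight.shape[1], target[1]
            reps = -(-in_c // in_w)  # ceiling division
            # Tile along the input-channel axis, trim, and rescale by in_w / in_c.
            weight = np.tile(weight, (1, reps, 1, 1))[:, :in_c] * (in_w / in_c)
        adapted[name] = weight
    return adapted
```

Detecting the mismatch from shapes, rather than hard-coding a key such as `conv1.weight`, is one way to share the same loading path across architectures whose first layer has different names. The adapted dict could then be converted to tensors and passed to model.load_state_dict as before.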

Alternatives

There is some ongoing work to add a ChangeDetectionTask that may store each image under a separate sample key. However, there will always be models that require images stacked along the channel dimension, so I don't think we can avoid supporting this use case.

Additional information

No response

keves1 commented 6 days ago

I'm interested in contributing to this; could I have it assigned to me? I looked into the load_pretrained method and found that it only copies the first-layer weights if the pre-trained weights have 3 input channels; otherwise, it uses random initialization for the first conv layer (see adapt_input_conv(), which is called by load_pretrained(), at the link below):

https://github.com/huggingface/pytorch-image-models/blob/ee5b1e8217134e9f016a0086b793c34abb721216/timm/models/_manipulate.py#L256-L278
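For reference, the branch logic described above can be paraphrased in NumPy. This is a simplified sketch of adapt_input_conv() at the linked commit, not timm's code: it omits some details (such as dtype handling), and `timm_style_adapt` is a hypothetical name. The key point is that for any target in_chans other than 1 or 3, only 3-channel weights can be converted.

```python
import math

import numpy as np


def timm_style_adapt(in_chans, conv_weight):
    """Simplified NumPy paraphrase of the branch logic in timm's adapt_input_conv()."""
    in_w = conv_weight.shape[1]
    if in_chans == 1:
        # Collapse all input channels into one by summing them.
        return conv_weight.sum(axis=1, keepdims=True)
    if in_chans != 3:
        if in_w != 3:
            # load_pretrained() catches this and falls back to random init for the layer.
            raise NotImplementedError('Weight format not supported by conversion.')
        repeat = math.ceil(in_chans / 3)
        conv_weight = np.tile(conv_weight, (1, repeat, 1, 1))[:, :in_chans]
        conv_weight = conv_weight * (3 / in_chans)
    return conv_weight
```

With 2-channel weights such as SENTINEL1_ALL_MOCO and in_chans=4, this raises NotImplementedError, which is exactly the case where load_pretrained() falls back to random initialization.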

So this would need to be adapted, since the weights we would use have varying numbers of input channels (e.g., 2 for the Sentinel-1 weights above).
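One way to adapt it, sketched under the assumption that we keep timm's repeat-and-rescale strategy, is to replace the hard-coded 3 with the weights' own channel count. `adapt_first_conv` is a hypothetical name, and NumPy arrays stand in for tensors; the same operations map onto `torch.Tensor.repeat`, `sum(dim=1, keepdim=True)`, and slicing.

```python
import math

import numpy as np


def adapt_first_conv(conv_weight: np.ndarray, in_chans: int) -> np.ndarray:
    """Map an (out, in_w, kh, kw) first-layer weight to in_chans input channels, for any in_w.

    Hypothetical generalization of timm's scheme: the hard-coded 3 becomes in_w.
    """
    in_w = conv_weight.shape[1]
    if in_chans == in_w:
        return conv_weight
    if in_chans == 1:
        # Collapse all pre-trained input channels into a single one.
        return conv_weight.sum(axis=1, keepdims=True)
    repeat = math.ceil(in_chans / in_w)
    tiled = np.tile(conv_weight, (1, repeat, 1, 1))[:, :in_chans]
    # Rescale to keep activations on the original scale; this is exact when in_chans
    # is a multiple of in_w and the extra channels repeat the input.
    return tiled * (in_w / in_chans)
```

When in_chans is smaller than the weights' channel count (and not 1), this keeps only the first in_chans channels, mirroring what timm does for 3-channel weights; whether that is a sensible default for multispectral weights is an open design question.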