IBM / terratorch

A Python toolkit for fine-tuning Geospatial Foundation Models (GFMs).
Apache License 2.0

Fine-tuning with tubelet_size greater than 1 #128

Open · Foxigod opened this issue 3 weeks ago

Foxigod commented 3 weeks ago

@CarlosGomes98

Describe the issue
I'm interested in fine-tuning a ViT model with the patch-embedding size set to something greater than 1 in the temporal dimension. To do this, I at least have to make a change similar to #118. When running after this change, I get an error like the following:

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [80, 19200] but got: [80, 9216].

which seems to stem from:

File "<PATH>/terratorch/terratorch/models/heads/classification_head.py", line 55, in forward
  x = self.head(x)

I believe the issue lies in an inconsistency between the TemporalViTEncoder's prepare_features_for_image_model function and the calculation of the "num_chs" values in the TemporalViTEncoder, but I'm not fully certain.

Could someone explain the transformation that happens in prepare_features_for_image_model? What is its purpose, and do you have any references explaining why it is formulated the way it is?

To Reproduce (optional, but appreciated)
I'm performing a classification task with the following config entry, including the aforementioned change from #118:

[...]
model:
  class_path: terratorch.tasks.ClassificationTask
  init_args:
    model_args:
      decoder: IdentityDecoder
      pretrained: true
      backbone: prithvi_vit_100
      backbone_pretrained_cfg_overlay:
        file: Prithvi_100M.pt
      backbone_patch_size: 3
      backbone_pretrain_img_size: 15
      backbone_tubelet_size: 3
[...]
CarlosGomes98 commented 2 weeks ago

Hi!

This looks like a case of confusion caused by the (admittedly numerous) configuration parameters for the backbone. Let me try to clarify them below:

num_frames

This parameter controls the number of frames (the length of the temporal sequence) that the Prithvi ViT model expects. The batch should have its dimensions in the order Batch, Channels, Time, Height, Width, where Time should equal num_frames. num_frames=1 is very common, so it is treated as a special case: if Time = 1, that dimension is created automatically when it does not exist. See https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/backbones/vit_encoder_decoder.py#L433-L434
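To make the expected layout concrete, here is a minimal sketch; the concrete shapes (8 samples, 6 bands, 224x224 pixels) are illustrative assumptions, not values taken from terratorch:

import torch

# Illustrative batch in the expected (Batch, Channels, Time, Height, Width)
# layout, with Time equal to num_frames (here 3).
batch = torch.randn(8, 6, 3, 224, 224)

# With num_frames=1, a 4D batch is also accepted; the missing Time dimension
# is inserted automatically, roughly equivalent to:
single_frame = torch.randn(8, 6, 224, 224)
single_frame = single_frame.unsqueeze(2)  # -> (8, 6, 1, 224, 224)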

tubelet_size

This is a more advanced parameter. It determines the temporal depth of the 3D convolution used for patch embedding in this transformer. In the majority of cases, it should be set to 1, which results in one set of patches per frame.
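As a rough sketch of what that means (the values for in_chans, embed_dim and patch_size are assumptions, and this is not the terratorch implementation): the temporal kernel and stride of the patch-embedding convolution equal tubelet_size, so frames are grouped in chunks of tubelet_size before being turned into tokens.

import torch
import torch.nn as nn

# Assumed values for illustration only.
in_chans, embed_dim, patch_size, tubelet_size = 6, 768, 16, 3
patch_embed = nn.Conv3d(
    in_chans,
    embed_dim,
    kernel_size=(tubelet_size, patch_size, patch_size),
    stride=(tubelet_size, patch_size, patch_size),
)

x = torch.randn(2, in_chans, 3, 224, 224)  # (B, C, T=3, H, W)
print(patch_embed(x).shape)                # torch.Size([2, 768, 1, 14, 14])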

In your case, I believe what you want is to remove the backbone_tubelet_size key and instead use num_frames: 3 under model_args. Note that this is an argument we pass to the model factory rather than to the model directly, so the backbone_ prefix is not necessary.

encoder x decoder interplay in PrithviModelFactory

The PrithviModelFactory relies on the feature_info attribute of the encoder in order to create a decoder with a number of channels appropriate for the embedding dimension. You can see how this has a multiplicative relationship with the num_frames parameter, as expected: https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/backbones/vit_encoder_decoder.py#L205-L206
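As a back-of-the-envelope illustration (assuming an embedding dimension of 768 for prithvi_vit_100; the exact bookkeeping lives in the linked code):

# Assumed values: the channel count advertised to the decoder grows
# multiplicatively with num_frames.
embed_dim = 768
num_frames = 3
num_chs = embed_dim * num_frames  # 2304 channels reported via feature_info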

prepare_features_for_image_model

A final piece that is missing is what we do to the output of the encoder before passing it to the decoder. In some cases, for example in a classic UNet, we may not need to do anything. However, in many cases we do need to perform some reshaping, permuting, pooling, and so on. This is what prepare_features_for_image_model takes care of.

This function can be custom-made by the user and passed to the PrithviModelFactory. When it is not, a default defined by the encoder is used (and if the encoder does not define one, the identity function is used). For the ViT model, you can see that function here: https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/backbones/vit_encoder_decoder.py#L459-L488

For each embedding produced by the model, it reshapes the output from a one-dimensional sequence of tokens into something more image-like, with a height and width, which can then be processed by, e.g., CNNs.
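Schematically, the reshaping looks roughly like this (the shapes are assumptions for illustration, not the exact terratorch code):

import torch

# A sequence of L = H_p * W_p patch tokens with embedding dimension D is
# rearranged into an image-like (B, D, H_p, W_p) feature map.
B, H_p, W_p, D = 2, 14, 14, 768
tokens = torch.randn(B, H_p * W_p, D)                  # (B, L, D)
feature_map = tokens.permute(0, 2, 1).reshape(B, D, H_p, W_p)
print(feature_map.shape)                               # torch.Size([2, 768, 14, 14])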

In your case, the classification head still expects the input to be in this image-like shape. There are two main reasons for this:

  1. For consistency with segmentation heads, making it easier to swap out classification and segmentation for the same underlying model
  2. In order to enable CNN-like decoders to work with the classification head

However, you can see how it immediately inverts this reshaping in its forward pass: https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/heads/classification_head.py#L48-L49
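In spirit, that inversion is just the reverse of the reshape sketched above (again with assumed shapes, not the actual head code):

import torch

# Flatten the image-like (B, D, H_p, W_p) map back into a token sequence
# before pooling and the final linear layer.
B, D, H_p, W_p = 2, 768, 14, 14
feature_map = torch.randn(B, D, H_p, W_p)
tokens_again = feature_map.flatten(2).permute(0, 2, 1)  # (B, H_p * W_p, D)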

Let me know if this helps. I'm happy to discuss it further together with how all this fits into your specific use case :)

CarlosGomes98 commented 2 weeks ago

On further discussion, the issue seems to be that prepare_features_for_image_model and feature_info disregard tubelet_size, effectively treating it as 1.

To take this into account, note that the effective num_frames the model operates with is the number of frames passed to it by the dataset divided by the tubelet size. This is because the patch embedding is applied with a temporal stride of tubelet_size, resulting in num_frames // tubelet_size sets of spatial embeddings.
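With the values from the config in this issue, that works out to:

# 3 frames from the dataset, tubelet_size=3.
num_frames, tubelet_size = 3, 3
effective_num_frames = num_frames // tubelet_size  # 1 set of spatial embeddings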