Foxigod opened 3 weeks ago
Hi!
This looks like it was an issue of confusion with the (admittedly numerous) configuration parameters for the backbone. Let me try to clarify in the following:
### `num_frames`

This parameter controls the number of frames (the length of the temporal sequence) that the Prithvi ViT model expects. The batch should have its dimensions in the order Batch, Channels, Time, Height, Width, where Time should equal `num_frames`. Since `num_frames=1` is very common, it is treated as a special case: if Time = 1, this dimension is created automatically if it does not exist: https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/backbones/vit_encoder_decoder.py#L433-L434
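For concreteness, here is a minimal sketch of the expected input shapes (the batch, channel, and spatial sizes below are placeholders, not values from your config):

```python
import torch

# A batch for a backbone configured with num_frames=3:
x = torch.randn(8, 6, 3, 224, 224)  # (Batch, Channels, Time, Height, Width)

# With num_frames=1, a 4D batch also works: the encoder inserts the
# missing Time dimension automatically (see the linked lines).
x_single = torch.randn(8, 6, 224, 224)  # (Batch, Channels, Height, Width)
```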
### `tubelet_size`

This is a more advanced parameter. It determines the depth of the 3D convolution used for patch embedding in this transformer. In the majority of cases it should be set to 1, which results in one set of patches per frame.
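As a rough illustration, here is a tubelet-style patch embedding assuming a timm/MAE-style `Conv3d` projection; this is a sketch, not terratorch's exact code:

```python
import torch
import torch.nn as nn

in_chans, embed_dim, patch_size, tubelet_size = 6, 768, 16, 1
proj = nn.Conv3d(
    in_chans, embed_dim,
    kernel_size=(tubelet_size, patch_size, patch_size),
    stride=(tubelet_size, patch_size, patch_size),
)
x = torch.randn(2, in_chans, 3, 224, 224)  # (B, C, T, H, W)
out = proj(x)  # (B, embed_dim, T // tubelet_size, 14, 14)
# tubelet_size=1 -> one 14x14 grid of patches per frame;
# tubelet_size=t -> every t consecutive frames embedded together.
```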
In your case, I believe what you want is to remove the `backbone_tubelet_size` key and instead use `num_frames: 3` under `model_args`. Notice this is an argument we pass to the model factory rather than to the model directly, so the `backbone_` prefix is not necessary.
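Shown here as a Python dict for illustration (in a YAML config these would be keys under `model_args`):

```python
model_args = {
    "num_frames": 3,  # consumed by the model factory, so no backbone_ prefix
    # "backbone_tubelet_size": 3,  # <- remove this key
}
```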
### `PrithviModelFactory`

The `PrithviModelFactory` relies on the `feature_info` attribute of the encoder in order to create a decoder with an appropriate number of channels for the dimension of the embedding. You can see how this has a multiplicative relationship with the `num_frames` parameter, as expected: https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/backbones/vit_encoder_decoder.py#L205-L206
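An illustrative reconstruction of that multiplicative relationship (the field names follow timm's `feature_info` convention, which is an assumption here):

```python
embed_dim, num_frames, depth = 768, 3, 12
feature_info = [
    dict(num_chs=embed_dim * num_frames, reduction=1, module=f"blocks.{i}")
    for i in range(depth)
]
# The decoder is then built for embed_dim * num_frames channels per feature.
```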
### `prepare_features_for_image_model`

A final piece that is missing is what we do to the output of the encoder before passing it to the decoder. In some cases, for example in a classic UNet, we may not need to do anything. In many cases, however, we do need to perform some reshaping, permuting, pooling, etc. This is what `prepare_features_for_image_model` takes care of.

This function can be custom made by the user and passed to the `PrithviModelFactory`; when it is not, a default defined by the encoder is used (and if the encoder does not define one, the identity function is used). For the ViT model, we can see that function here: https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/backbones/vit_encoder_decoder.py#L459-L488

For each embedding produced by the model, it reshapes the unidimensional set of tokens into something more image-like, with a height and width, which can then be processed by, e.g., CNNs.
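A self-contained sketch of that token-to-image reshape, assuming a square patch grid and a cls token that gets dropped (simplified to a single frame):

```python
import torch

B, D, grid = 2, 768, 14
tokens = torch.randn(B, 1 + grid * grid, D)       # (B, 1 + H*W, D)
x = tokens[:, 1:, :]                              # drop the cls token
x = x.permute(0, 2, 1).reshape(B, D, grid, grid)  # (B, D, H, W): CNN-friendly
```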
In your case, the classification head still expects the input to be in this image-like shape. You can see, however, how it immediately inverts this reshaping in its forward pass: https://github.com/IBM/terratorch/blob/ffc431a09ab8bdafeb184e6668e4450b264597c2/terratorch/models/heads/classification_head.py#L48-L49
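A hedged sketch of what that inversion looks like (see the linked lines for the actual code): the image-like feature is flattened back into a token sequence before pooling and classification.

```python
import torch

x = torch.randn(2, 768, 14, 14)    # (B, D, H, W) from the reshape above
x = x.flatten(2).transpose(1, 2)   # back to (B, H*W, D)
```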
Let me know if this helps. I'm happy to discuss it further together with how all this fits into your specific use case :)
On further discussion, the issue seems to be that `prepare_features_for_image_model` and `feature_info` disregard `tubelet_size`, effectively setting it to 1.
In order to take it into account, we can see that the effective `num_frames` the model operates with is the number of frames passed to it by the dataset divided by the tubelet size. This is because the patch embedding is done with a stride of `tubelet_size`, resulting in `num_frames // tubelet_size` sets of spatial embeddings.
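A worked example of that arithmetic:

```python
# The temporal stride of the patch embedding shrinks the effective
# number of frames the rest of the model sees.
num_frames, tubelet_size = 6, 2
effective_frames = num_frames // tubelet_size  # 3 sets of spatial embeddings
# So feature_info and prepare_features_for_image_model should scale by
# num_frames // tubelet_size, not num_frames, when tubelet_size > 1.
```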
@CarlosGomes98

### Describe the issue

I'm interested in fine-tuning a ViT model with the patch-embedding size set to something greater than 1 for the temporal dimension. To do this, I at least have to perform something similar to #118. When running after this change, I get an error like the following:

which seems to stem from:

I believe the issue to be inconsistencies between the `TemporalViTEncoder`'s `prepare_features_for_image_model` function and the calculation of these "num_chs" values in the `TemporalViTEncoder`, however I'm not fully certain.

Can I ask someone to explain to me the transformation that is going on in `prepare_features_for_image_model`? What is its purpose, and do you have any references explaining why it is formulated the way it is?

### To Reproduce (optional, but appreciated)

I'm performing a classification task with the following config entry, with the aforementioned change like #118: