mikeskaug / prithvi-change-detection

Using the Prithvi geospatial foundation model for change detection

Embedding reshaping and segmentation head #2

Open mikeskaug opened 6 months ago

mikeskaug commented 6 months ago

I've been working on setting up the segmentation head to accept the encoder output, and I've come across a question related to the two-image input in this use case.

My plan had been to combine the "pre" and "post" disaster images into a single model input: https://github.com/mikeskaug/prithvi-change-detection/blob/8214f9482d59ff5230fb29bc640efbe9dbbea49d/dataset.py#L35
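A minimal sketch of what "a single model input" could look like, assuming a Prithvi-style encoder that takes a (B, C, T, H, W) tensor and tokenizes it with 16x16 spatial patches and a tubelet size of 1 (check these against the actual model config). `make_model_input` and `expected_seq_len` are hypothetical helpers, not part of the repo. With two frames, the encoder sequence is twice as long as for a single image, plus the class token:

```python
import torch


def make_model_input(pre: torch.Tensor, post: torch.Tensor) -> torch.Tensor:
    """Stack (B, C, H, W) pre/post images into one (B, C, T=2, H, W) tensor."""
    if pre.shape != post.shape:
        raise ValueError(f"shape mismatch: {tuple(pre.shape)} vs {tuple(post.shape)}")
    return torch.stack([pre, post], dim=2)  # new time axis after the channel axis


def expected_seq_len(num_frames: int, img_size: int, patch_size: int,
                     tubelet_size: int = 1, cls_token: bool = True) -> int:
    """Encoder output length for a tubelet (3D patch) embedding plus optional class token."""
    patches_per_side = img_size // patch_size
    return (num_frames // tubelet_size) * patches_per_side ** 2 + int(cls_token)


if __name__ == "__main__":
    pre = torch.randn(1, 6, 224, 224)   # 6 spectral bands assumed
    post = torch.randn(1, 6, 224, 224)
    print(make_model_input(pre, post).shape)  # torch.Size([1, 6, 2, 224, 224])
    print(expected_seq_len(1, 224, 16))       # 197: one image
    print(expected_seq_len(2, 224, 16))       # 393: pre + post
```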

As a result, the encoder output is a longer sequence that contains information from both images. (How the information is distributed across the sequence is something I'm still confused about. From what I've seen, most users assume that the output tokens correspond one-to-one with the input tokens, so that you can, for example, drop the class token and reshape the rest into two spatial dimensions: num_patches x num_patches x embedding_size, where num_patches is the number of patches along each side of the image. For now, I'm making the same assumption, even though I don't understand why it's that simple.)

My initial thought is to reshape the encoder output by splitting the sequence in half, reshaping each to [num_patches x num_patches x embedding_size] and then stacking them in the embedding/channel dimension: https://github.com/mikeskaug/prithvi-change-detection/blob/8214f9482d59ff5230fb29bc640efbe9dbbea49d/model.py#L131

This gives a tensor of shape [num_patches x num_patches x 2*embedding_size].
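The split-and-stack reshape can be sketched as below. It assumes (1) the patch embedding is a Conv3d followed by `flatten(2)`, so tokens are ordered frame first, then row, then column, and (2) the encoder returns tokens in input order. A standard ViT block outputs one vector per input position, in the same order, but MAE-style random masking would shuffle the tokens unless it is disabled or undone. Under those assumptions the reshape is just the inverse of the tokenization, which is why dropping the class token and reshaping works. `tokens_to_stacked_grid` is a hypothetical name, and the sketch returns channels-first (B, 2*D, N, N), the layout `torch.nn.Conv2d` expects, rather than N x N x 2*D.

```python
import torch
import torch.nn as nn


def tokens_to_stacked_grid(tokens: torch.Tensor, num_frames: int = 2,
                           has_cls_token: bool = True) -> torch.Tensor:
    """(B, [1 +] T*N*N, D) encoder tokens -> (B, T*D, N, N) channel-stacked grid."""
    if has_cls_token:
        tokens = tokens[:, 1:, :]                   # drop the class token
    b, seq_len, d = tokens.shape
    n = round((seq_len / num_frames) ** 0.5)        # patches per side
    if num_frames * n * n != seq_len:
        raise ValueError(f"sequence length {seq_len} is not {num_frames} * N * N")
    grid = tokens.reshape(b, num_frames, n, n, d)   # split the sequence per frame
    grid = grid.permute(0, 1, 4, 2, 3)              # (B, T, D, N, N)
    return grid.reshape(b, num_frames * d, n, n)    # stack frames along channels


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy Conv3d patch embedding: 6 bands -> 8-dim tokens, 16x16 patches, tubelet size 1.
    embed = nn.Conv3d(6, 8, kernel_size=(1, 16, 16), stride=(1, 16, 16))
    x = torch.randn(2, 6, 2, 64, 64)                # (B, C, T, H, W): pre and post
    feat = embed(x)                                 # (B, D, T, N, N) = (2, 8, 2, 4, 4)
    tokens = feat.flatten(2).transpose(1, 2)        # (B, T*N*N, D)
    tokens = torch.cat([torch.zeros(2, 1, 8), tokens], dim=1)  # dummy class token
    stacked = tokens_to_stacked_grid(tokens)
    print(stacked.shape)  # torch.Size([2, 16, 4, 4])
    # Channels 0..7 are the pre-image grid, channels 8..15 the post-image grid.
    print(torch.allclose(stacked, torch.cat([feat[:, :, 0], feat[:, :, 1]], dim=1)))  # True
```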

If I pass this through the decoder stack as it's currently designed (https://github.com/mikeskaug/prithvi-change-detection/blob/8214f9482d59ff5230fb29bc640efbe9dbbea49d/model.py#L110), each convolutional filter will be convolved with, and summed across, the features from both images. That seems weird, but maybe it's a simple way to start? Another option would be to use groups=2 in the torch.nn.Conv2d layers so that the two images get separate filters, and then apply the final classification layer to all of the output feature maps?
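A sketch of the two options, assuming a decoder built from repeated ConvTranspose2d/Conv2d upsampling blocks (the repo's actual decoder at model.py#L110 may differ) and `num_classes=5` (four damage levels plus a no-building class, an assumption). With `groups=1` every filter mixes pre and post channels; with `groups=2` each half of the filters sees only one image, and the final 1x1 classifier, left ungrouped, is where the two are combined. `ChangeDecoder` is a hypothetical name.

```python
import torch
import torch.nn as nn


class ChangeDecoder(nn.Module):
    """Upsampling decoder over channel-stacked (pre, post) features of shape (B, 2*D, N, N)."""

    def __init__(self, embed_dim: int = 768, hidden: int = 256, num_classes: int = 5,
                 num_upsamples: int = 4, groups: int = 1):
        super().__init__()
        layers = []
        in_ch = 2 * embed_dim
        for _ in range(num_upsamples):
            layers += [
                # groups=2 keeps pre channels and post channels in separate filter banks
                nn.ConvTranspose2d(in_ch, hidden, kernel_size=2, stride=2, groups=groups),
                nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, groups=groups),
                nn.BatchNorm2d(hidden),  # per-channel, so it never mixes the groups
                nn.ReLU(inplace=True),
            ]
            in_ch = hidden
        self.decoder = nn.Sequential(*layers)
        # Ungrouped 1x1 classifier: every class score sees features from both images.
        self.classifier = nn.Conv2d(hidden, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.decoder(x))


if __name__ == "__main__":
    feats = torch.randn(1, 2 * 768, 14, 14)  # 224x224 image, 16x16 patches
    for g in (1, 2):
        with torch.no_grad():
            print(g, ChangeDecoder(groups=g)(feats).shape)  # (1, 5, 224, 224)
```

Four 2x upsampling steps take the 14x14 patch grid back to 224x224. Grouping roughly halves the decoder's parameter count, since each filter sees only half the channels.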

Does anyone have any thoughts on this reshaping or how the decoding/segmentation stack should be applied?

blumenstiel commented 6 months ago

I think it is correct to use the output tokens from the embedding in this way.

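Combining the embeddings of both timestamps is one possibility. Two other possible approaches:

  • Only use the embeddings of t1 as input for the decoder.
  • Process each timestamp separately (concatenate along a new dim/the batch dim or use the groups=2 setting). I don't know if the weights are shared between the groups; you might want to check this. You can split the embeddings and apply the classifier twice. This would result in two segmentation maps. If the dataset does not include segmentation maps for t0, we could assume that all buildings have no damage.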

I don't know which approach would work best. Maybe another paper already compares these approaches. Otherwise, it could be a good experiment to do the comparison.

blumenstiel commented 6 months ago

The classifier only has 4 classes. How do you handle the no-building class? My approach would be to add it as a fifth class.

mikeskaug commented 6 months ago

> I think it is correct to use the output tokens from the embedding in this way.
>
> Combining the embeddings of both timestamps is one possibility. Two other possible approaches:
>
>   • Only use the embeddings of t1 as input for the decoder.
>   • Process each timestamp separately (concatenate along a new dim/the batch dim or use the groups=2 setting). I don't know if the weights are shared between the groups; you might want to check this. You can split the embeddings and apply the classifier twice. This would result in two segmentation maps. If the dataset does not include segmentation maps for t0, we could assume that all buildings have no damage.
>
> I don't know which approach would work best. Maybe another paper already compares these approaches. Otherwise, it could be a good experiment to do the comparison.

Thanks for your input on this. Yes, maybe this will require some experimentation. I will also search the literature for hints.

For what it's worth, I think using the groups=2 argument means different weights for the different groups. From the torch.nn.Conv2d documentation:

> At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
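The documented behavior can be checked directly: a grouped convolution's weight has shape (out_channels, in_channels // groups, kH, kW), and copying its two slices into two ordinary convolutions reproduces its output (up to floating-point tolerance). So the two groups have independent parameters, and nothing ties the pre-image filters to the post-image filters. The channel counts below are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_ch, out_ch = 8, 6
grouped = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, groups=2)
print(grouped.weight.shape)  # torch.Size([6, 4, 3, 3]): each filter sees in_ch // 2 channels

# Two ordinary convs built from the grouped conv's two weight slices.
half_in, half_out = in_ch // 2, out_ch // 2
conv_a = nn.Conv2d(half_in, half_out, kernel_size=3, padding=1)
conv_b = nn.Conv2d(half_in, half_out, kernel_size=3, padding=1)
with torch.no_grad():
    conv_a.weight.copy_(grouped.weight[:half_out])
    conv_a.bias.copy_(grouped.bias[:half_out])
    conv_b.weight.copy_(grouped.weight[half_out:])
    conv_b.bias.copy_(grouped.bias[half_out:])

x = torch.randn(1, in_ch, 5, 5)
# First half of the input channels -> conv_a, second half -> conv_b, then concatenate.
side_by_side = torch.cat([conv_a(x[:, :half_in]), conv_b(x[:, half_in:])], dim=1)
print(torch.allclose(grouped(x), side_by_side, atol=1e-5))  # True
```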

mikeskaug commented 6 months ago

> The classifier only has 4 classes. How do you handle the no-building class? My approach would be to add it as a fifth class.

Oh, wow! Thank you for catching this. I was making a wrong assumption about how the data was labeled. I thought all the damage was labeled in each image, so anything not labeled was undamaged, but clearly that is wrong. I will add an "unclassified" class for anything not labeled and exclude it from the loss calculation.
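One way to exclude unlabeled pixels without adding an output channel is `torch.nn.CrossEntropyLoss(ignore_index=...)`, which skips targets equal to that index and (with the default `reduction="mean"`) averages over the remaining pixels. A minimal sketch; the label value 255 for unclassified pixels is an assumption, not the dataset's actual encoding.

```python
import torch
import torch.nn as nn

UNCLASSIFIED = 255  # assumption: label value for pixels without a damage annotation
criterion = nn.CrossEntropyLoss(ignore_index=UNCLASSIFIED)

torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)              # (B, 4 damage classes, H, W)
target = torch.randint(0, 4, (2, 8, 8))       # (B, H, W) damage class per pixel
target[:, :4, :] = UNCLASSIFIED               # pretend the top half of each tile is unlabeled
loss = criterion(logits, target)              # averaged over labeled pixels only

# Same value computed by hand on the labeled pixels.
mask = target != UNCLASSIFIED
manual = nn.functional.cross_entropy(logits.permute(0, 2, 3, 1)[mask], target[mask])
print(torch.allclose(loss, manual))  # True
```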

blumenstiel commented 6 months ago

Then I assume that using groups=2 and adding a temporal dimension (for shared weights) are two different approaches.
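The two approaches differ in where the weights live: `groups=2` gives each image its own filters, while folding the time axis into the batch axis runs one shared decoder over both frames and yields one segmentation map per timestamp. A minimal sketch of the shared-weight variant, assuming the same frame-major token order a Conv3d patch embedding produces; `decode_per_timestep` and the toy decoder are hypothetical.

```python
import torch
import torch.nn as nn


def decode_per_timestep(tokens: torch.Tensor, decoder: nn.Module, num_frames: int = 2,
                        has_cls_token: bool = True) -> torch.Tensor:
    """(B, [1 +] T*N*N, D) tokens -> (B, T, K, H_out, W_out), one map per timestamp."""
    if has_cls_token:
        tokens = tokens[:, 1:, :]                      # drop the class token
    b, seq_len, d = tokens.shape
    n = round((seq_len / num_frames) ** 0.5)           # patches per side
    grid = tokens.reshape(b, num_frames, n, n, d).permute(0, 1, 4, 2, 3)  # (B, T, D, N, N)
    folded = grid.reshape(b * num_frames, d, n, n)     # time folded into batch
    out = decoder(folded)                              # same weights for every frame
    return out.reshape(b, num_frames, *out.shape[1:])  # unfold back to (B, T, ...)


if __name__ == "__main__":
    torch.manual_seed(0)
    d, n = 8, 4
    decoder = nn.Sequential(                           # toy shared decoder
        nn.ConvTranspose2d(d, 16, kernel_size=2, stride=2),
        nn.ReLU(),
        nn.Conv2d(16, 5, kernel_size=1),
    )
    tokens = torch.randn(2, 1 + 2 * n * n, d)
    maps = decode_per_timestep(tokens, decoder)
    print(maps.shape)  # torch.Size([2, 2, 5, 8, 8])
```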

blumenstiel commented 6 months ago

Not sure I would exclude the unknown class from the loss; it might lead to more false-positive predictions. You could consider using a weighted loss instead.
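A sketch of the weighted-loss alternative: keep no-building as a real class (index 0 here, an assumption) and pass per-class weights to `torch.nn.CrossEntropyLoss(weight=...)` so the dominant background class doesn't swamp the damage classes. Inverse-frequency weighting is just one common choice, not something prescribed in this thread, and `inverse_frequency_weights` is a hypothetical helper. In practice the counts would come from the training masks rather than a toy vector.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 5  # assumption: 0 = no building, 1-4 = damage levels


def inverse_frequency_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Weight each class by total_pixels / (num_classes * class_pixels)."""
    counts = torch.bincount(labels.flatten(), minlength=num_classes).float()
    return counts.sum() / (num_classes * counts.clamp(min=1))  # clamp avoids divide-by-zero


labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 2, 3, 4])  # background-heavy toy labels
weights = inverse_frequency_weights(labels, NUM_CLASSES)
print(weights)  # roughly [0.33, 2, 2, 2, 2]: rare classes weigh more

criterion = nn.CrossEntropyLoss(weight=weights)
logits = torch.randn(2, NUM_CLASSES, 8, 8)             # (B, K, H, W)
target = torch.randint(0, NUM_CLASSES, (2, 8, 8))      # (B, H, W)
loss = criterion(logits, target)
```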