allenai / satlaspretrain_models

Apache License 2.0

Use ignore_index for the loss and add support for infrared band in the Heads #10

Open killian31 opened 5 months ago

killian31 commented 5 months ago

Contents

This PR brings two features: use of ignore_index in the loss, and support for an additional infrared band in the heads.

Example Usage

import satlaspretrain_models
import torch

num_classes = 10
# load model
model = satlaspretrain_models.Weights().get_pretrained_model(
    model_identifier="Aerial_SwinB_SI",
    fpn=True,
    head=satlaspretrain_models.Head.SEGMENT,
    num_categories=num_classes,
    additional_bands=1,
    ignore_index=255,
)

# freeze backbone, fpn and unfreeze upsampler and head
for param in model.parameters():
    param.requires_grad = False
for param in model.upsample.parameters():
    param.requires_grad = True
for param in model.head.parameters():
    param.requires_grad = True

batch_size = 8
num_channels = 3 + 1 # rgb and infrared

# dummy input 
x = torch.randint(0, 256, (batch_size, num_channels, 512, 512)).float() / 255.0
# dummy mask
y = torch.randint(0, num_classes, (batch_size, 512, 512))

model_input = x[:, :3, :, :]
other_bands = x[:, 3:, :, :]
# compute model outputs
out_backbone = model.backbone(model_input)
out_fpn = model.fpn(out_backbone)
out_upsample = model.upsample(out_fpn)

# add infrared data to the output of the upsampler (at index 0 since it is the only one used in the head)
# other_bands is already 4-D (batch, 1, H, W), so it is concatenated as-is along the channel dim
out_upsample[0] = torch.cat((out_upsample[0], other_bands), dim=1)
# run the head
probas, loss = model.head(model_input, out_upsample, y)
y_hat = torch.argmax(probas, dim=1) # predictions
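
To make the effect of ignore_index=255 concrete, here is a minimal pure-Python sketch of a mean cross-entropy that skips pixels carrying the ignore label. It follows the documented behaviour of PyTorch's cross-entropy with ignore_index (ignored targets contribute nothing, and the mean is taken over the non-ignored targets). It is not the code from this PR: masked_cross_entropy is a hypothetical helper name, and returning 0.0 when every pixel is ignored is an assumption made for this sketch.

```python
import math

IGNORE_INDEX = 255  # same value as passed to get_pretrained_model above


def masked_cross_entropy(logits, targets, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over the pixels whose label is not ignore_index.

    logits: list of per-pixel score lists (one score per class).
    targets: list of integer labels, one per pixel.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, targets):
        if label == ignore_index:
            continue  # ignored pixel: contributes nothing to the loss
        # numerically stable log-sum-exp over the class scores
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]
        count += 1
    # assumption for this sketch: 0.0 when every pixel is ignored
    return total / count if count else 0.0


# three "pixels", three classes
logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.0, 0.0, 5.0]]
loss_all = masked_cross_entropy(logits, [0, 1, 2])
loss_masked = masked_cross_entropy(logits, [0, 1, IGNORE_INDEX])
```

In loss_masked the third pixel is dropped, so the result equals the mean loss over the first two pixels only, whatever the model predicts for the ignored one.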

killian31 commented 2 months ago

Hello!

We are evaluating Satlas with our own data containing 4 bands (R, G, B, NIR) and adding a 4th band as input to the segmentation head does give better results. We are using this branch for now, but it would help us a lot if this could be added to your repo! Could you maybe think about it?

Thank you, Killian

Bencpr commented 2 months ago

Hi, I am very interested in both changes proposed here by @killian31.

thanks! 💯