allenai / satlaspretrain_models

Segmentation Task: looking for a decoder head upsample and segment to the input image size #6

Closed: Bencpr closed this issue 5 months ago

Bencpr commented 5 months ago

Hi there,

I've successfully fine-tuned the Aerial_SwinB_SI pretrained model on a classification task.

I am now trying to fine-tune this model on a semantic segmentation task with a custom decoder/head that also upsamples the model's output to the original image size, since the largest feature map the model produces is only 1/4 of the original image size.

I have tried various options, such as Upsample and ConvTranspose2d layers in combination with standard Conv2d/ReLU layers. Here is an example from the CLAY pretrained model's fine-tuning documentation:

Model(
  (decoder): Sequential(
    (0): Conv2d(4608, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): Upsample(scale_factor=2.0, mode='nearest')
    (2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Upsample(scale_factor=2.0, mode='nearest')
    (5): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Upsample(scale_factor=2.0, mode='nearest')
    (8): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Upsample(scale_factor=2.0, mode='nearest')
    (11): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): Upsample(scale_factor=2.0, mode='nearest')
  )
)

Source: https://clay-foundation.github.io/model/model_finetuning.html

I first tried simple decoder architectures that I placed just before the provided "HEAD=segment" head, such as:

SimpleHead(
  (layers): Sequential(
    # custom insertion HERE
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Upsample(scale_factor=2.0, mode='nearest')
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): Upsample(scale_factor=2.0, mode='nearest')
    # HEAD=segment provided by code
    (6): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (7): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

So far, I didn't get anything close to useful.
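
For reference, the shape arithmetic behind the head described above can be checked in isolation. This is a minimal sketch, not the library's head: the class name UpsamplingSegHead is hypothetical, and the 128 input channels and 2 output classes are assumptions taken from the printout above. Two 2x upsampling steps take 1/4-resolution features back to the input size.

```python
import torch
from torch import nn


class UpsamplingSegHead(nn.Module):
    """Hypothetical head: 1/4-resolution features -> full-resolution logits."""

    def __init__(self, in_channels: int = 128, num_classes: int = 2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0, mode="nearest"),  # 1/4 -> 1/2 resolution
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0, mode="nearest"),  # 1/2 -> full resolution
            nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.layers(features)


head = UpsamplingSegHead()
# Stand-in for the 1/4-resolution feature map of a 64x64 input image.
features = torch.randn(1, 128, 16, 16)
logits = head(features)
print(logits.shape)
```

If the spatial size of the logits does not match the mask size here, the loss cannot be computed at full resolution, so this is a cheap check to run before training.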

In the paper, when the SATLAS models are fine-tuned on the segmentation task, do you segment at 1/4 of the original image resolution?

Any idea to make this work is welcome! :)

Thanks, Ben

favyen2 commented 5 months ago

We actually have similar up-sampling layers in the decoder to segment at the input resolution. The original code is here: https://github.com/allenai/satlas/blob/main/satlas/model/model.py#L350, but it looks like it did not make it into this version. @piperwolters, can you take a look? We might want to always include the upsampling module for segmentation tasks.

What issues did you get when you prepended those up-sampling layers manually?

Bencpr commented 5 months ago

Thanks @favyen2 ,

The main issue is that the head does not converge at all and the segmentation maps are garbage.

Thanks for the pointer to the Upsampler code. How do you use it in practice? In combination with a segmentation head, or in place of one?

favyen2 commented 5 months ago

Do you have a code snippet showing how you are using the model? Also, for segmentation it is important to enable the FPN; otherwise the features at 1/4 of the original image resolution won't have context from the deeper layers. So if you aren't passing fpn=True, that is one thing to try.
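
To illustrate why the FPN matters here, below is a generic FPN-style top-down pathway. It is a sketch of the general technique, not the satlaspretrain_models implementation: the class name TinyTopDown is hypothetical, and the channel widths (128, 256, 512, 1024) and strides (4, 8, 16, 32) are assumed Swin-B-like values. Each shallower map receives an upsampled copy of the merged deeper map, which is how the 1/4-resolution features pick up deeper context.

```python
import torch
from torch import nn
import torch.nn.functional as F


class TinyTopDown(nn.Module):
    """Generic FPN-style top-down merge (illustrative only)."""

    def __init__(self, in_channels=(128, 256, 512, 1024), out_channels=128):
        super().__init__()
        # 1x1 lateral convolutions bring every level to the same channel width.
        self.lateral = nn.ModuleList(
            [nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels]
        )

    def forward(self, feats):
        laterals = [conv(f) for conv, f in zip(self.lateral, feats)]
        merged = [laterals[-1]]  # start from the deepest (coarsest) level
        for lat in reversed(laterals[:-1]):
            # Upsample the previously merged map and add it to the shallower level.
            up = F.interpolate(merged[0], size=lat.shape[-2:], mode="nearest")
            merged.insert(0, lat + up)
        return merged  # ordered from finest (1/4) to coarsest (1/32)


# Backbone features for a 64x64 input at strides 4, 8, 16 and 32.
feats = [
    torch.randn(1, 128, 16, 16),
    torch.randn(1, 256, 8, 8),
    torch.randn(1, 512, 4, 4),
    torch.randn(1, 1024, 2, 2),
]
merged_maps = TinyTopDown()(feats)
print([tuple(m.shape) for m in merged_maps])
```

Without this merge, the finest map would only contain what the first backbone stage computed.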

Bencpr commented 5 months ago

I'm using torchgeo, so I override the configure_models method and inject the Satlas pretrained model.

import torch
import satlaspretrain_models
from torchgeo.trainers import SemanticSegmentationTask

# Weights manager used to load the pretrained checkpoints
weights_manager = satlaspretrain_models.Weights()

# Custom SemanticSegmentationTask that loads the SatlasPretrain model
class SatlasSemanticSegmentationTask(SemanticSegmentationTask):

    def __init__(self, model_identifier: str, fpn: bool = True, pretrained: bool = True, *args, **kwargs):

        self.model_identifier = model_identifier
        self.pretrained = pretrained
        self.fpn = fpn
        # call super method
        super().__init__(*args, **kwargs)

    def configure_models(self):

        self.model = weights_manager.get_pretrained_model(
            model_identifier=self.model_identifier,
            fpn=self.fpn,
            head=satlaspretrain_models.Head.SEGMENT,
            num_categories=self.hparams["num_classes"]
        )

        # add upconv layers to Head
        head_in_channels = self.model.head.layers[0][0].in_channels
        upconv = torch.nn.Sequential(
            torch.nn.Conv2d(head_in_channels, head_in_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
            torch.nn.Conv2d(head_in_channels, head_in_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
        )

        # place upconv layers before provided head layers
        modules = list(upconv.children()) + list(self.model.head.layers.children())
        self.model.head.layers = torch.nn.Sequential(*modules)

        # Freeze the whole model, then unfreeze only the segmentation head
        if self.hparams["freeze_backbone"]:
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.model.head.parameters():
                param.requires_grad = True

Call it with

task = SatlasSemanticSegmentationTask(
    num_classes=2,
    freeze_backbone=True,
    model_identifier="Aerial_SwinB_SI",
    pretrained=True,
    fpn=True
)

Then I use the torchgeo Trainer on a custom semantic segmentation datamodule (binary).

With torchgeo's default ResNet50 model I obtain satisfactory segmentation outputs, so the datamodule side is fine.

Thanks

Bencpr commented 5 months ago

Hi @favyen2, do you have any insight into how to use the Upsampler?

Thank you :)

favyen2 commented 5 months ago

I'll take a look.

For the Upsampler, that module should go after the FPN but before the SimpleHead. It takes the feature maps that come out of the FPN as input, prepends the upsampled result, and returns the combined list (so element 0 is the upsampled features and elements 1-4 are from the FPN). During pre-training, the object detection head uses the FPN outputs, while segmentation/regression tasks use only the up-sampled features.
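
The list layout described above can be sketched as follows. This is an illustration under stated assumptions, not the code from the satlas repository: the class name PrependUpsampled and its internal layers are hypothetical, and the FPN outputs are assumed to have 128 channels with the finest map at 1/4 resolution.

```python
import torch
from torch import nn


class PrependUpsampled(nn.Module):
    """Hypothetical module: upsample the finest FPN map 4x and prepend it."""

    def __init__(self, channels: int = 128):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0, mode="nearest"),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0, mode="nearest"),
        )

    def forward(self, fpn_feats):
        upsampled = self.layers(fpn_feats[0])
        # Element 0 is the full-resolution map; the FPN maps follow unchanged.
        return [upsampled] + list(fpn_feats)


# Four FPN outputs for a 64x64 input (strides 4, 8, 16, 32).
fpn_feats = [torch.randn(1, 128, s, s) for s in (16, 8, 4, 2)]
out = PrependUpsampled()(fpn_feats)
print(len(out), tuple(out[0].shape))
```

A segmentation head placed after such a module would read only element 0, while a detection head would read the remaining elements.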

favyen2 commented 5 months ago

The upsampler is now included in the satlaspretrain_models library (see https://github.com/allenai/satlaspretrain_models/pull/7) so you could try that as well.

I get an error when using the SatlasSemanticSegmentationTask you provided:

TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not tuple

This is because the segmentation head returns an (output, loss) tuple, but only the output should be passed to the cross entropy loss in the torchgeo trainer task.
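
One way to adapt a model that returns a tuple to a trainer that expects a tensor is a thin wrapper that keeps only the first element. The sketch below uses a toy stand-in (TupleModel) rather than the real Satlas model, and the wrapper name OutputOnly is hypothetical.

```python
import torch
from torch import nn
import torch.nn.functional as F


class TupleModel(nn.Module):
    """Toy stand-in for a model whose forward returns (output, loss)."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.conv = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x, targets=None):
        logits = self.conv(x)
        loss = F.cross_entropy(logits, targets) if targets is not None else None
        return logits, loss


class OutputOnly(nn.Module):
    """Hypothetical wrapper that discards the loss and returns only the output."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        output, _ = self.model(x)
        return output


wrapped = OutputOnly(TupleModel())
x = torch.randn(2, 3, 8, 8)
y = torch.randint(0, 2, (2, 8, 8))
logits = wrapped(x)
loss = F.cross_entropy(logits, y)  # now receives a tensor, not a tuple
print(tuple(logits.shape), loss.item())
```

The alternative is to leave the model as-is and unpack the tuple inside custom training and validation steps.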

Bencpr commented 5 months ago

Hi @favyen2, great news, thanks! I'll try the upsampler right now.

So both the Upsampler and the Head should be unfrozen to be trainable, not only the Head. It might also be interesting to have a pretrained upsampler. What do you think?
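
Unfreezing more than one submodule can be sketched with a small helper. The attribute names backbone, upsample and head on the toy model are assumptions for illustration, as is the helper name freeze_all_but; check the attribute names on the actual model before reusing this.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Toy model with the three parts discussed above (names are assumed)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2.0, mode="nearest"),
        )
        self.head = nn.Conv2d(8, 2, kernel_size=1)


def freeze_all_but(model: nn.Module, trainable_names):
    """Freeze every parameter, then re-enable gradients for the named submodules."""
    for param in model.parameters():
        param.requires_grad = False
    for name in trainable_names:
        for param in getattr(model, name).parameters():
            param.requires_grad = True


toy = ToyModel()
freeze_all_but(toy, ["upsample", "head"])
trainable = [n for n, p in toy.named_parameters() if p.requires_grad]
print(trainable)
```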

About your error: sorry, I did not provide the custom training_step and validation_step that handle the tuple returned by the model:

def training_step(
  self, batch, batch_idx: int, dataloader_idx: int = 0
):
  """Compute the training loss and additional metrics.

  Args:
      batch: The output of your DataLoader.
      batch_idx: Integer displaying index of this batch.
      dataloader_idx: Index of the current dataloader.

  Returns:
      The loss tensor.
  """
  x = batch["image"]
  y = batch["mask"]
  batch_size = x.shape[0]
  probas, loss = self.model(x, y.squeeze())
  y_hat = torch.max(probas, dim=1).indices
  self.log("train_loss", loss, batch_size=batch_size)
  self.train_metrics(y_hat, y)
  self.log_dict(self.train_metrics, batch_size=batch_size, on_epoch=True, prog_bar=True)
  return loss

def validation_step(
  self, batch, batch_idx: int, dataloader_idx: int = 0
) -> None:
  """Compute the validation loss and additional metrics.

  Args:
      batch: The output of your DataLoader.
      batch_idx: Integer displaying index of this batch.
      dataloader_idx: Index of the current dataloader.
  """
  x = batch["image"]
  y = batch["mask"]
  batch_size = x.shape[0]
  probas, loss = self.model(x, y.squeeze())
  y_hat = torch.max(probas, dim=1).indices
  self.log("val_loss", loss, batch_size=batch_size)
  self.val_metrics(y_hat, y)
  self.log_dict(self.val_metrics, batch_size=batch_size, on_epoch=True, prog_bar=True)

Bencpr commented 5 months ago

Hi again @favyen2, I did not manage to get the segmentation training to work using the upsampler you merged into the main branch.

Do you have a working example, ideally using torchgeo for reproducibility?

Thank you, Ben

favyen2 commented 5 months ago

Have you tried reducing the learning rate? I get good performance using torchgeo with the SatlasSemanticSegmentationTask you provided, but I needed to set the learning rate to 1e-4.
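
In plain PyTorch, the learning rate suggested above would be set on the optimizer, as in the sketch below. The toy model and the choice of AdamW are assumptions for illustration. With torchgeo, the equivalent is usually to pass the learning rate to the task constructor, assuming the installed version exposes such an argument.

```python
import torch
from torch import nn

# Toy model: a frozen "backbone" layer followed by a trainable "head" layer.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 2, kernel_size=1),
)
for param in model[0].parameters():
    param.requires_grad = False

# Only hand the trainable parameters to the optimizer, with the lower learning rate.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
print(optimizer.param_groups[0]["lr"], len(trainable))
```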