Closed. Bencpr closed this issue 5 months ago.
We actually have similar up-sampling layers in the decoder to segment at the input resolution. The original code is here: https://github.com/allenai/satlas/blob/main/satlas/model/model.py#L350, but it looks like it did not make it into this version. @piperwolters, can you take a look? We might want to always include the upsampling module for segmentation tasks.
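As a quick check of what "segment at the input resolution" requires: if the finest backbone feature map has stride 4, two 2x up-sampling stages bring it back to full size. The sketch below assumes a Swin-style pyramid with strides 4, 8, 16 and 32. The function names are hypothetical and not part of satlas or satlaspretrain_models.

```python
import math


def num_2x_upsamples(stride: int) -> int:
    """Number of 2x upsampling stages needed to undo a power-of-two stride."""
    n = int(math.log2(stride))
    if 2 ** n != stride:
        raise ValueError(f"stride {stride} is not a power of two")
    return n


def feature_sizes(input_size: int, strides=(4, 8, 16, 32)):
    """Side length of each backbone feature map for a square input (assumed strides)."""
    return [input_size // s for s in strides]


print(feature_sizes(512))   # [128, 64, 32, 16]
print(num_2x_upsamples(4))  # 2: stride 4 -> 2 -> 1
```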
What issues did you get when you prepended those up-sampling layers manually?
Thanks @favyen2,
The main issue I have is that the head simply does not converge at all, and the segmentation maps are garbage.
Thanks for the code pointer to the Upsampler. How do you use it in practice: in combination with a segmentation head, or in place of one?
Do you have a code snippet showing how you are using the model? Also, for segmentation it is important to enable the FPN, because otherwise the features at 1/4 of the original image resolution won't have context from the deeper layers. If you aren't passing fpn=True, that is one thing to try.
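To illustrate why the FPN matters here, below is a minimal top-down feature pyramid in plain PyTorch: each coarser level is upsampled and added into the next finer one, so the stride-4 output mixes in context from the deepest layers. This is a hypothetical `TinyFPN` sketch of the general technique, not the FPN implementation used by satlaspretrain_models; the Swin-B channel widths (128, 256, 512, 1024) are an assumption.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyFPN(nn.Module):
    """Hypothetical minimal top-down FPN, not the satlaspretrain_models one.

    1x1 lateral convs project each level to `out_channels`; walking from the
    coarsest level to the finest, each level gets the upsampled level above it
    added in, so the stride-4 output also carries deep-layer context.
    """

    def __init__(self, in_channels, out_channels: int = 128):
        super().__init__()
        self.lateral = nn.ModuleList(
            [nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels]
        )
        self.smooth = nn.ModuleList(
            [nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) for _ in in_channels]
        )

    def forward(self, feats):
        # feats are ordered finest (stride 4) to coarsest (stride 32)
        maps = [conv(f) for conv, f in zip(self.lateral, feats)]
        for i in range(len(maps) - 2, -1, -1):
            # maps[i + 1] has already received everything above it
            maps[i] = maps[i] + F.interpolate(maps[i + 1], size=maps[i].shape[-2:], mode="nearest")
        return [conv(m) for conv, m in zip(self.smooth, maps)]


channels = [128, 256, 512, 1024]  # assumed Swin-B channel widths
feats = [torch.randn(1, c, 64 // 2**i, 64 // 2**i) for i, c in enumerate(channels)]
outs = TinyFPN(channels)(feats)
print([tuple(o.shape) for o in outs])
# [(1, 128, 64, 64), (1, 128, 32, 32), (1, 128, 16, 16), (1, 128, 8, 8)]
```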
I'm using TorchGeo, so I override the configure_models method and inject the SatlasPretrain pretrained model.
```python
import torch
import satlaspretrain_models
from torchgeo.trainers import SemanticSegmentationTask

weights_manager = satlaspretrain_models.Weights()


# Custom SemanticSegmentationTask to load in the SatlasPretrain model
class SatlasSemanticSegmentationTask(SemanticSegmentationTask):
    def __init__(self, model_identifier: str, fpn: bool = True, pretrained: bool = True, *args, **kwargs):
        # Set these before super().__init__(), which calls configure_models()
        self.model_identifier = model_identifier
        self.pretrained = pretrained  # note: not used by configure_models below
        self.fpn = fpn
        super().__init__(*args, **kwargs)

    def configure_models(self):
        self.model = weights_manager.get_pretrained_model(
            model_identifier=self.model_identifier,
            fpn=self.fpn,
            head=satlaspretrain_models.Head.SEGMENT,
            num_categories=self.hparams["num_classes"],
        )
        # Add up-convolution layers to the head: two conv/ReLU/2x-upsample
        # stages take the 1/4-resolution features back to input resolution
        head_in_channels = self.model.head.layers[0][0].in_channels
        upconv = torch.nn.Sequential(
            torch.nn.Conv2d(head_in_channels, head_in_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Upsample(scale_factor=2.0, mode="nearest"),
            torch.nn.Conv2d(head_in_channels, head_in_channels, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Upsample(scale_factor=2.0, mode="nearest"),
        )
        # Place the upconv layers before the existing head layers
        modules = list(upconv.children()) + list(self.model.head.layers.children())
        self.model.head.layers = torch.nn.Sequential(*modules)

        # Freeze the backbone and unfreeze the head
        if self.hparams["freeze_backbone"]:
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.model.head.parameters():
                param.requires_grad = True
```
Call it with:

```python
task = SatlasSemanticSegmentationTask(
    num_classes=2,
    freeze_backbone=True,
    model_identifier="Aerial_SwinB_SI",
    pretrained=True,
    fpn=True,
)
```
Then I train it with a Lightning Trainer on a custom binary semantic segmentation datamodule.
With TorchGeo's default ResNet-50 model I get satisfactory segmentation outputs, so everything is fine on the datamodule side.
Thanks
Hi @favyen2, do you have any insight into how to use the upsampler?
Thank you :)
I'll take a look.
For the Upsampler, that module should go after the FPN but before the SimpleHead. It takes the feature maps that come out of the FPN, prepends the upsampled result, and returns the combined list (so element 0 is the upsampled features and elements 1-4 are the FPN outputs). During pre-training, the object detection head uses the FPN outputs, while segmentation/regression tasks use only the up-sampled features.
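The data flow described above can be sketched as a small module that upsamples the finest FPN map and prepends it to the list. This is a hypothetical `PrependUpsampler`, written only to show the list layout (index 0 upsampled, 1-4 from the FPN). The conv/ReLU/2x-upsample stages, the 128-channel width and the two-stage depth are assumptions, not the satlaspretrain_models code.

```python
import torch
from torch import nn


class PrependUpsampler(nn.Module):
    """Hypothetical sketch: upsample fpn_feats[0] and prepend it to the list."""

    def __init__(self, channels: int = 128, num_stages: int = 2):
        super().__init__()
        layers = []
        for _ in range(num_stages):
            layers += [
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2.0, mode="nearest"),
            ]
        self.layers = nn.Sequential(*layers)

    def forward(self, fpn_feats):
        upsampled = self.layers(fpn_feats[0])  # finest FPN map, 2**num_stages larger
        return [upsampled] + list(fpn_feats)   # element 0 is new, 1..N are the FPN maps


fpn_feats = [torch.randn(1, 128, 64 // 2**i, 64 // 2**i) for i in range(4)]
out = PrependUpsampler()(fpn_feats)
print(len(out), tuple(out[0].shape))  # 5 (1, 128, 256, 256)

# A segmentation head would read only element 0; a detection head, elements 1..4
seg_logits = nn.Conv2d(128, 2, kernel_size=1)(out[0])
```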
The upsampler is now included in the satlaspretrain_models library (see https://github.com/allenai/satlaspretrain_models/pull/7), so you could try that as well.
I get an error when using the SatlasSemanticSegmentationTask you provided:
TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not tuple
This happens because the segmentation head returns an (output, loss) tuple, but only the output should be passed to the cross-entropy loss in the TorchGeo trainer task.
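One way around the tuple, if you want to keep TorchGeo's own loss, is to wrap the model so that it returns only the first element. The sketch below is a hypothetical `OutputsOnly` wrapper. It also assumes the first element holds per-class probabilities rather than logits (the training_step later in this thread calls them `probas`); in that case it returns their log, because `F.cross_entropy` applies log-softmax, which leaves normalized log-probabilities unchanged.

```python
import torch
import torch.nn.functional as F
from torch import nn


class OutputsOnly(nn.Module):
    """Hypothetical wrapper for a model whose forward returns (outputs, loss)."""

    def __init__(self, model: nn.Module, outputs_are_probs: bool = True, eps: float = 1e-8):
        super().__init__()
        self.model = model
        self.outputs_are_probs = outputs_are_probs
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.model(x)
        if isinstance(out, tuple):
            out = out[0]  # drop the loss (None when no targets are given)
        if self.outputs_are_probs:
            # log-probabilities are valid cross_entropy inputs: log_softmax(log p) == log p
            out = torch.log(out.clamp_min(self.eps))
        return out


class _ToySegModel(nn.Module):
    """Stand-in that mimics the (probabilities, loss) return convention."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=1)

    def forward(self, x, targets=None):
        return F.softmax(self.conv(x), dim=1), None


toy = _ToySegModel()
x = torch.randn(2, 3, 8, 8)
y = torch.randint(0, 2, (2, 8, 8))
loss = F.cross_entropy(OutputsOnly(toy)(x), y)
```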
Hi @favyen2, great news, thanks! I'll try the upsampler right now.
So both the Upsampler and the Head need to be unfrozen to be trainable, not only the Head. It might also be interesting to have a pretrained upsampler; what do you think?
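A generic way to freeze everything except several submodules is to match parameter names by prefix. In this sketch, `freeze_all_but` is a hypothetical helper and the prefixes `"upsample."` and `"head."` are placeholders; the real attribute names in satlaspretrain_models may differ, so inspect `model.named_children()` first.

```python
from torch import nn


def freeze_all_but(model: nn.Module, trainable_prefixes) -> None:
    """Freeze every parameter whose name does not start with one of the prefixes."""
    for name, param in model.named_parameters():
        param.requires_grad = any(name.startswith(p) for p in trainable_prefixes)


# Toy stand-in with placeholder submodule names
toy = nn.ModuleDict({
    "backbone": nn.Linear(4, 4),
    "upsample": nn.Linear(4, 4),
    "head": nn.Linear(4, 2),
})
freeze_all_but(toy, ("upsample.", "head."))
print(sorted(n for n, p in toy.named_parameters() if p.requires_grad))
# ['head.bias', 'head.weight', 'upsample.bias', 'upsample.weight']
```

The trailing dot in each prefix prevents accidental matches such as `header.weight` for `head`.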
About your error: sorry, I did not include my custom training_step and validation_step, which handle the tuple returned by the model:
```python
# Methods added to SatlasSemanticSegmentationTask. The SatlasPretrain model
# returns an (outputs, loss) tuple, so the loss comes from the model itself.
def training_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
    """Compute the training loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.

    Returns:
        The loss tensor.
    """
    x = batch["image"]
    y = batch["mask"]
    if y.ndim == 4:  # [B, 1, H, W] -> [B, H, W]; keeps the batch dim when B == 1
        y = y.squeeze(1)
    batch_size = x.shape[0]
    probas, loss = self.model(x, y)
    y_hat = probas.argmax(dim=1)
    self.log("train_loss", loss, batch_size=batch_size)
    self.train_metrics(y_hat, y)
    self.log_dict(self.train_metrics, batch_size=batch_size, on_epoch=True, prog_bar=True)
    return loss


def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
    """Compute the validation loss and additional metrics.

    Args:
        batch: The output of your DataLoader.
        batch_idx: Integer displaying index of this batch.
        dataloader_idx: Index of the current dataloader.
    """
    x = batch["image"]
    y = batch["mask"]
    if y.ndim == 4:  # [B, 1, H, W] -> [B, H, W]; keeps the batch dim when B == 1
        y = y.squeeze(1)
    batch_size = x.shape[0]
    probas, loss = self.model(x, y)
    y_hat = probas.argmax(dim=1)
    self.log("val_loss", loss, batch_size=batch_size)
    self.val_metrics(y_hat, y)
    self.log_dict(self.val_metrics, batch_size=batch_size, on_epoch=True, prog_bar=True)
```
Hi again @favyen2, I have not managed to get segmentation training working with the upsampler you merged into the main branch.
Do you have a working example, e.g. with TorchGeo, for reproducibility?
Thank you, Ben
Have you tried reducing the learning rate? I get good performance using TorchGeo with the SatlasSemanticSegmentationTask you provided, but I needed to set the learning rate to 1e-4.
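If you build the optimizer yourself, a minimal sketch of the suggested setting is AdamW at lr=1e-4 over the trainable parameters only, so the frozen backbone is skipped. The choice of AdamW and the helper name `make_optimizer` are assumptions; TorchGeo's SemanticSegmentationTask also takes an `lr` argument, which is the simpler place to set it.

```python
import torch
from torch import nn


def make_optimizer(model: nn.Module, lr: float = 1e-4) -> torch.optim.Optimizer:
    """AdamW over trainable parameters only (hypothetical helper)."""
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(params, lr=lr)


toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in toy[0].parameters():  # pretend the first layer is the frozen backbone
    p.requires_grad = False
opt = make_optimizer(toy)
```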
Hi there,
I've successfully fine-tuned the Aerial_SwinB_SI pretrained model on a classification task.
I am now trying to fine-tune this model on a semantic segmentation task with a custom decoder/head that also upsamples the model's output to the original image size, because the largest image embedding is only 1/4 of the original resolution.
I have tried various options, such as Upsample and ConvTranspose2D layers combined with standard Conv2D/ReLU layers. Here is an example provided for the Clay pretrained model:
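The simplest, non-learned baseline for this mismatch is to predict logits at 1/4 resolution and bilinearly resize them to the mask size before computing the loss, a common pattern in segmentation decoders. This is a minimal sketch of that alternative (the helper name is hypothetical), not the SatlasPretrain head.

```python
import torch
import torch.nn.functional as F


def upsample_logits(logits: torch.Tensor, size) -> torch.Tensor:
    """Bilinearly resize [B, C, h, w] logits to size=(H, W) before the loss."""
    return F.interpolate(logits, size=size, mode="bilinear", align_corners=False)


logits = torch.randn(2, 2, 64, 64)          # 1/4-resolution predictions of a 256x256 image
mask = torch.randint(0, 2, (2, 256, 256))   # full-resolution binary mask
full = upsample_logits(logits, mask.shape[-2:])
loss = F.cross_entropy(full, mask)
```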
Source: https://clay-foundation.github.io/model/model_finetuning.html
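For comparison with the Upsample-based variant, here is a generic transposed-convolution decoder: each `ConvTranspose2d(kernel_size=2, stride=2)` exactly doubles the spatial size, so two stages undo a stride of 4. This is a sketch of the general technique, not the Clay code; the class name, the channel halving, and the BatchNorm/ReLU choices are assumptions.

```python
import torch
from torch import nn


class TransposedConvDecoder(nn.Module):
    """Hypothetical decoder: stride-4 features -> full-resolution class logits."""

    def __init__(self, in_channels: int = 128, num_classes: int = 2, num_stages: int = 2):
        super().__init__()
        layers = []
        c = in_channels
        for _ in range(num_stages):
            # kernel 2 / stride 2: output size = (h - 1) * 2 + 2 = 2h
            layers += [
                nn.ConvTranspose2d(c, c // 2, kernel_size=2, stride=2),
                nn.BatchNorm2d(c // 2),
                nn.ReLU(inplace=True),
            ]
            c //= 2
        layers.append(nn.Conv2d(c, num_classes, kernel_size=1))  # per-pixel logits
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


feats = torch.randn(1, 128, 64, 64)  # e.g. stride-4 features of a 256x256 image
logits = TransposedConvDecoder()(feats)
```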
I first tried simple decoder architectures placed just before the provided "HEAD=segment" head, such as:
So far, I haven't gotten anything close to useful.
In the paper, when SATLAS models are fine-tuned on the segmentation task, do you segment at 1/4 of the original image resolution?
Any idea to make this work is welcome! :)
Thanks, Ben