allenai / satlaspretrain_models

Apache License 2.0

Issue with torchgeo_demo training #15

Open shashankvasisht opened 2 days ago

shashankvasisht commented 2 days ago

Hi, thank you for this amazing work! I was trying out your torchgeo_demo.ipynb on my local machine. I only changed the 'max_epochs' parameter and ran the rest of the code blocks as is.

I let the model train for about 60 epochs and saw that the loss stayed almost flat from the beginning, and so did the accuracy. I understand that the model (swin_v2_B) is a transformer architecture and that the dataset used is very small (UC_Merced, with only ~1200 training samples). Is this why the model is not able to learn, or am I doing something incorrectly?

Please find the code and output screenshots below:

import os
import torch
import tempfile
from lightning.pytorch import Trainer

from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
from torchgeo.datamodules import UCMercedDataModule
from torchgeo.trainers import ClassificationTask

# Experiment Arguments
batch_size = 8
num_workers = 2
max_epochs = 150
fast_dev_run = False

# Torchgeo lightning datamodule to initialize dataset
root = os.path.join(tempfile.gettempdir(), "ucm")
datamodule = UCMercedDataModule(
    root=root, batch_size=batch_size, num_workers=num_workers, download=True
)

# Custom ClassificationTask to load in the SatlasPretrain model
class SatlasClassificationTask(ClassificationTask):
    def configure_models(self):
        weights = Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS
        self.model = swin_v2_b(weights)

        # Replace the first layer's input channels with the task's number of input channels (3 for RGB).
        first_layer = self.model.features[0][0]
        self.model.features[0][0] = torch.nn.Conv2d(3,
                                    first_layer.out_channels,
                                    kernel_size=first_layer.kernel_size,
                                    stride=first_layer.stride,
                                    padding=first_layer.padding,
                                    bias=(first_layer.bias is not None))

        # Replace the last layer's output features with the number of classes.
        self.model.head = torch.nn.Linear(in_features=1024, out_features=self.hparams["num_classes"], bias=True)

    def on_validation_epoch_end(self):
        # Accessing metrics logged during the current validation epoch
        val_loss = self.trainer.callback_metrics.get('val_loss', 'N/A')
        val_acc = self.trainer.callback_metrics.get('val_OverallAccuracy', 'N/A')
        print(f"Epoch {self.current_epoch} Validation - Loss: {val_loss}, Accuracy: {val_acc}")

# Initialize the Classification Task
task = SatlasClassificationTask(num_classes=21)

# Initialize the training code.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
default_root_dir = os.path.join(tempfile.gettempdir(), "experiments")

trainer = Trainer(
    accelerator=accelerator,
    default_root_dir=default_root_dir,
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    min_epochs=1,
    max_epochs=max_epochs,
)

# Train
trainer.fit(model=task, datamodule=datamodule)

[Output screenshots]
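One more detail in the configure_models override above: the torch.nn.Conv2d that replaces self.model.features[0][0] is freshly (randomly) initialized. It has the same shape as the pretrained patch-embedding layer, but that layer's pretrained weights are discarded. The sketch below assumes the SENTINEL2_SI_RGB_SATLAS weights already expect 3 input channels, as the RGB in the name suggests. It copies the pretrained kernel whenever the channel count is unchanged, and it sizes the new head from the existing head instead of hardcoding 1024. The helper names rebuild_first_conv and rebuild_head are hypothetical and not part of torchgeo.

```python
import torch
from torch import nn


def rebuild_first_conv(old: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Hypothetical helper: replacement patch-embedding conv with `in_channels` inputs.

    A new nn.Conv2d starts from random init, so when the channel count is
    unchanged the pretrained kernel (and bias) are copied over instead of
    being thrown away. With a different channel count it stays random.
    """
    new = nn.Conv2d(
        in_channels,
        old.out_channels,
        kernel_size=old.kernel_size,
        stride=old.stride,
        padding=old.padding,
        bias=old.bias is not None,
    )
    if in_channels == old.in_channels:
        with torch.no_grad():
            new.weight.copy_(old.weight)
            if old.bias is not None:
                new.bias.copy_(old.bias)
    return new


def rebuild_head(old: nn.Linear, num_classes: int) -> nn.Linear:
    """Hypothetical helper: new classifier sized from the existing head's in_features."""
    return nn.Linear(old.in_features, num_classes, bias=True)
```

Inside configure_models these would replace the two inline constructions, e.g. self.model.features[0][0] = rebuild_first_conv(self.model.features[0][0], 3) and self.model.head = rebuild_head(self.model.head, self.hparams["num_classes"]). For Swin V2 B the existing head's in_features is 1024, so the head is the same size either way; reading it from the layer just avoids the magic number.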

favyen2 commented 2 days ago

I tried it and got the same issue; I think it needs a lower learning rate:

task = SatlasClassificationTask(num_classes=21, lr=1e-4)

Let me know if this works for you, and I can adjust the notebook accordingly (and also fix the SENTINEL2_RGB_SI_SATLAS vs. SENTINEL2_SI_RGB_SATLAS naming problem; I think they renamed the weights in the final version).

shashankvasisht commented 2 days ago

Hi, thanks for your reply. I also tried using a StepLR scheduler but still saw similar behaviour.

favyen2 commented 2 days ago

It doesn't work for you when passing lr=1e-4 to SatlasClassificationTask (so that the lr is reduced from the beginning of training)?

shashankvasisht commented 1 day ago

Okay, yes! I found an error in my scheduler implementation. After fixing it and keeping the initial LR at 1e-4, it now seems to work fine. Thank you!

import os
import torch
import tempfile
from lightning.pytorch import Trainer
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
from torchgeo.datamodules import UCMercedDataModule
from torchgeo.trainers import ClassificationTask

# Experiment Arguments
batch_size = 8
num_workers = 2
max_epochs = 150
fast_dev_run = False

# Torchgeo lightning datamodule to initialize dataset
root = os.path.join(tempfile.gettempdir(), "ucm")
datamodule = UCMercedDataModule(
    root=root, batch_size=batch_size, num_workers=num_workers, download=True
)

# Custom ClassificationTask to load in the SatlasPretrain model
class SatlasClassificationTask(ClassificationTask):
    def configure_models(self):
        weights = Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS
        self.model = swin_v2_b(weights)

        # Replace the first layer's input channels with the task's number of input channels (3 for RGB).
        first_layer = self.model.features[0][0]
        self.model.features[0][0] = torch.nn.Conv2d(3,
                                    first_layer.out_channels,
                                    kernel_size=first_layer.kernel_size,
                                    stride=first_layer.stride,
                                    padding=first_layer.padding,
                                    bias=(first_layer.bias is not None))

        # Replace the last layer's output features with the number of classes.
        self.model.head = torch.nn.Linear(in_features=1024, out_features=self.hparams["num_classes"], bias=True)

    def configure_optimizers(self):
        optimizer = Adam(self.model.parameters(), lr=1e-4)
        scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
        # Lightning accepts ([optimizers], [schedulers]); epoch-level schedulers step once per epoch by default.
        return [optimizer], [scheduler]

    def on_validation_epoch_end(self):
        # Accessing metrics logged during the current validation epoch
        val_loss = self.trainer.callback_metrics.get('val_loss', 'N/A')
        val_acc = self.trainer.callback_metrics.get('val_OverallAccuracy', 'N/A')
        print(f"Epoch {self.current_epoch} Validation - Loss: {val_loss}, Accuracy: {val_acc}")

# Initialize the Classification Task
task = SatlasClassificationTask(num_classes=21)

# Initialize the training code.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
default_root_dir = os.path.join(tempfile.gettempdir(), "experiments")

trainer = Trainer(
    accelerator=accelerator,
    default_root_dir=default_root_dir,
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    min_epochs=1,
    max_epochs=max_epochs,
)

# Train
trainer.fit(model=task, datamodule=datamodule)

[Output screenshot]
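Since the remaining problem turned out to be a scheduler bug, a quick way to catch that kind of mistake is to print the learning rate the optimizer actually uses at each epoch. The sketch below reproduces the Adam + StepLR settings from the configure_optimizers above (initial LR 1e-4, step_size=10, gamma=0.1) on a single dummy parameter. lr_per_epoch is a hypothetical helper, and the one-update-per-epoch loop is only a stand-in for real training.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR


def lr_per_epoch(num_epochs, base_lr=1e-4, step_size=10, gamma=0.1):
    """Hypothetical helper: learning rate in effect at the start of each epoch.

    A dummy parameter stands in for the model; only the optimizer and
    scheduler settings matter here.
    """
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = Adam([param], lr=base_lr)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

    lrs = []
    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        # Stand-in for one training epoch: a single dummy update.
        optimizer.zero_grad()
        loss = (param - 1.0).pow(2).sum()
        loss.backward()
        optimizer.step()
        # Lightning steps epoch-level schedulers once per epoch by default.
        scheduler.step()
    return lrs


if __name__ == "__main__":
    lrs = lr_per_epoch(40)
    for epoch in (0, 9, 10, 20, 30, 39):
        print(f"epoch {epoch:3d}: lr = {lrs[epoch]:.0e}")
```

With these settings the rate drops tenfold every 10 epochs: 1e-4 for epochs 0 to 9, 1e-5 for epochs 10 to 19, 1e-7 by epoch 30 and 1e-8 by epoch 40, so most of a 150-epoch run happens at a very small learning rate. If that is not intended, a larger step_size or a milder gamma keeps the rate higher for longer.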