mberkay0 / pretrained-backbones-unet

A PyTorch-based Python library with UNet architecture and multiple backbones for Image Semantic Segmentation.
MIT License

Model saving functionality missing #1

Closed: Tharunkumar01 closed this issue 1 year ago

Tharunkumar01 commented 1 year ago

Thanks for creating this library! I was trying to implement the same thing myself, because the segmentation models pytorch library doesn't really provide ConvNeXt as an encoder.

I hope you implement the saving functionality soon.

Thank you.

mberkay0 commented 1 year ago

Hello @Tharunkumar01, first of all, thank you. I will add an automatic model-saving feature based on the validation step (e.g. saving only the best model) and the number of epochs. Since I am busy, until this feature is implemented you can customize the trainer class by adding a simple step that saves a checkpoint after each epoch. Example custom trainer:

import os

import torch
from tqdm import trange

from backbones_unet.utils.trainer import Trainer


class CustomTrainer(Trainer):
    def __init__(
        self,
        model,
        criterion,
        optimizer,
        epochs,
        save_path=None,
        scaler=None,
        lr_scheduler=None,
        device=None
    ):
        super().__init__(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            epochs=epochs,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            device=device
        )
        # directory for per-epoch checkpoints (None disables saving)
        self.save_path = save_path

    def fit(self, train_loader, val_loader):
        """
        Customized fit function that also saves the model's
        state_dict after every epoch.
        """
        # per-epoch loss history
        self.train_losses_ = torch.zeros(self.epochs)
        self.val_losses_ = torch.zeros(self.epochs)
        # make sure the checkpoint directory exists
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
        # ---- train process ----
        for epoch in trange(1, self.epochs + 1, desc='Training Model on {} epochs'.format(self.epochs)):
            # train
            self._train_one_epoch(train_loader, epoch)
            # validate
            self._evaluate(val_loader, epoch)
            # save a checkpoint for this epoch, e.g. <save_path>/ckpt_3.pt
            if self.save_path:
                torch.save(self.model.state_dict(), os.path.join(self.save_path, f'ckpt_{epoch}.pt'))
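The snippet above saves every epoch; the "save best model only" behaviour mentioned in the reply is not part of it. A minimal sketch of how such a policy could work, assuming you have the validation loss for each epoch: `CheckpointPolicy`, `save_fn`, `every_n_epochs` and the `best.pt` / `ckpt_{epoch}.pt` file names are hypothetical and not part of backbones_unet. The actual write is passed in as a callback (for example `lambda p: torch.save(model.state_dict(), p)`), so the decision logic itself has no framework dependency.

```python
import math
import os
from typing import Callable, List, Optional


class CheckpointPolicy:
    """Hypothetical helper that decides when to write a checkpoint.

    - Saves `best.pt` whenever the validation loss improves (save-best-only).
    - Optionally also saves `ckpt_<epoch>.pt` every `every_n_epochs` epochs
      (0 disables periodic saving).
    The caller must make sure `save_dir` exists before training starts.
    """

    def __init__(self, save_dir: str, save_fn: Callable[[str], None],
                 every_n_epochs: int = 0):
        self.save_dir = save_dir
        self.save_fn = save_fn  # e.g. lambda p: torch.save(model.state_dict(), p)
        self.every_n_epochs = every_n_epochs
        self.best_loss = math.inf
        self.best_epoch: Optional[int] = None

    def step(self, epoch: int, val_loss: float) -> List[str]:
        """Call once per epoch after validation; returns the paths written."""
        written = []
        # strict improvement over the best loss seen so far
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            written.append(os.path.join(self.save_dir, "best.pt"))
        # periodic checkpoint, independent of validation performance
        if self.every_n_epochs > 0 and epoch % self.every_n_epochs == 0:
            written.append(os.path.join(self.save_dir, f"ckpt_{epoch}.pt"))
        for path in written:
            self.save_fn(path)
        return written


if __name__ == "__main__":
    saved = []  # stand-in for torch.save: just records the paths
    policy = CheckpointPolicy("checkpoints", saved.append, every_n_epochs=2)
    for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
        policy.step(epoch, loss)
    print(policy.best_epoch, saved)
```

One way to wire this into the `CustomTrainer` is to call `policy.step(epoch, self.val_losses_[epoch - 1].item())` right after `self._evaluate(val_loader, epoch)`. This assumes `_evaluate` stores the epoch's validation loss in `self.val_losses_`, which is worth checking against the installed version of the library. A saved state dict can later be restored with `model.load_state_dict(torch.load(path))`.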