Tharunkumar01 closed this issue 1 year ago.
Hello @Tharunkumar01, first of all, thank you. I will add an automatic model-saving feature based on the validation step (like "save best model only") and on the number of epochs. However, since I am busy, until this feature is implemented you could customize the trainer class by adding a simple step that saves a checkpoint after each epoch. Example custom trainer:
```python
import os

import torch
from tqdm import trange
from backbones_unet.utils.trainer import Trainer


class CustomTrainer(Trainer):
    def __init__(
        self,
        model,
        criterion,
        optimizer,
        epochs,
        save_path=None,
        scaler=None,
        lr_scheduler=None,
        device=None
    ):
        super().__init__(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            epochs=epochs,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            device=device
        )
        # Directory for per-epoch checkpoints (None disables saving).
        self.save_path = save_path
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)

    def fit(self, train_loader, val_loader):
        """
        Customized fit function: trains and validates once per epoch,
        then saves a checkpoint of the model weights.
        """
        # attributes
        self.train_losses_ = torch.zeros(self.epochs)
        self.val_losses_ = torch.zeros(self.epochs)
        # ---- train process ----
        for epoch in trange(1, self.epochs + 1, desc='Training Model on {} epochs'.format(self.epochs)):
            # train
            self._train_one_epoch(train_loader, epoch)
            # validate
            self._evaluate(val_loader, epoch)
            # model save: one file per epoch, e.g. <save_path>/ckpt_3.pt
            if self.save_path:
                torch.save(self.model.state_dict(), os.path.join(self.save_path, f'ckpt_{epoch}.pt'))
```
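Until the built-in feature lands, the two saving policies mentioned in the reply (keep the best model by validation loss, and save every N epochs) could be combined in a small helper. This is a minimal sketch under stated assumptions: `CheckpointSaver` is a hypothetical class, not part of backbones_unet, and it takes an injectable `save_fn` (defaulting to `torch.save`) so the decision logic can be checked without writing files. Wiring it into the trainer assumes the library's `_evaluate` records each epoch's validation loss in `self.val_losses_`, which should be verified against its `trainer.py`.

```python
import math
import os
from typing import Any, Callable, List, Optional


class CheckpointSaver:
    """Hypothetical helper, NOT part of backbones_unet.

    Tracks the best validation loss seen so far and, at the end of each epoch,
    decides which checkpoints to write:
      * ``best.pt`` whenever the validation loss improves ("save best only");
      * ``ckpt_{epoch}.pt`` every ``save_every`` epochs (0 disables this).
    """

    def __init__(
        self,
        save_dir: str,
        save_every: int = 0,
        save_fn: Optional[Callable[[Any, str], None]] = None,
    ) -> None:
        self.save_dir = save_dir
        self.save_every = save_every
        self.best_loss = math.inf
        if save_fn is None:
            import torch  # default writer: torch.save(obj, path)

            save_fn = torch.save
        self.save_fn = save_fn

    def step(self, epoch: int, val_loss: float, state_dict: Any) -> List[str]:
        """Write the checkpoints due for this epoch and return their paths."""
        os.makedirs(self.save_dir, exist_ok=True)
        written = []
        # Strict "<": a tie with the best loss does not overwrite best.pt.
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            path = os.path.join(self.save_dir, "best.pt")
            self.save_fn(state_dict, path)
            written.append(path)
        # Periodic checkpoint, independent of validation performance.
        if self.save_every > 0 and epoch % self.save_every == 0:
            path = os.path.join(self.save_dir, f"ckpt_{epoch}.pt")
            self.save_fn(state_dict, path)
            written.append(path)
        return written
```

Inside `CustomTrainer.fit`, one could call `saver.step(epoch, self.val_losses_[epoch - 1].item(), self.model.state_dict())` right after `self._evaluate(val_loader, epoch)`, again assuming `_evaluate` fills `self.val_losses_[epoch - 1]`. Keeping the "best" file under a fixed name means only one extra file on disk, while the periodic files give restart points if training is interrupted.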
Thanks for creating this library. I was trying to implement the same thing, because the Segmentation Models PyTorch library doesn't really provide ConvNeXt as an encoder.
I hope you implement the saving functionality soon.
Thank you.