plemeri / UACANet

Official PyTorch implementation of UACANet: Uncertainty Augmented Context Attention for Polyp Segmentation (ACMMM 2021)
MIT License

About checkpoint saved #14

Closed cqlouis closed 1 year ago

cqlouis commented 1 year ago

In run/Train.py, line 122:

```python
if epoch % opt.Train.Checkpoint.checkpoint_epoch == 0:
    torch.save(model.module.state_dict() if args.device_num > 1 else model.state_dict(),
               os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
```

To my understanding, this code fragment saves a checkpoint every 20 epochs, which cannot guarantee that the saved checkpoint is the optimal one reached during training.

And in line 130:

```python
if args.local_rank <= 0:
    torch.save(model.module.state_dict() if args.device_num > 1 else model.state_dict(),
               os.path.join(opt.Train.Checkpoint.checkpoint_dir, 'latest.pth'))
```

this code just saves the weights of the last epoch, so it also cannot guarantee that the saved checkpoint is the optimal one.

Why don't you save the optimal checkpoint? Would you mind explaining it for me? Many thanks to you! Happy New Year!

plemeri commented 1 year ago

Hi, the purpose of saving the model during training in this repository is to resume training if it stops unexpectedly. We only save every 20 epochs since each epoch takes only a few minutes. Saving the best checkpoint by evaluating the model on a validation dataset every epoch would be a good idea, but we did not set up a validation dataset.
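For anyone who does build a validation split, a best-checkpoint loop could be sketched roughly as below. Note this is only an illustration, not code from this repository: `train_with_best_checkpoint`, the `validate` callback, and the `best.pth` filename are all hypothetical names.

```python
import os
import torch
import torch.nn as nn

def train_with_best_checkpoint(model, validate, epochs, ckpt_dir):
    """Hypothetical sketch: save 'best.pth' whenever the validation
    metric (e.g. mean Dice on a held-out split) improves."""
    best_metric = float('-inf')
    for epoch in range(1, epochs + 1):
        # ... one training epoch would run here ...
        metric = validate(model)  # higher is assumed to be better
        if metric > best_metric:
            best_metric = metric
            # Overwrite the best checkpoint so far
            torch.save(model.state_dict(), os.path.join(ckpt_dir, 'best.pth'))
    return best_metric
```

The latest-checkpoint saving the repository already does could be kept alongside this for resuming interrupted runs; the two serve different purposes.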

Moreover, I agree that it might not be the optimal checkpoint, but to keep the training setting unambiguous for a fair comparison on the test benchmarks, we only use the latest checkpoint. It would be cheating if we searched for and used a checkpoint from the middle of training.

Thanks, and happy new year to you too!