gmberton / deep-visual-geo-localization-benchmark

Official code for CVPR 2022 (Oral) paper "Deep Visual Geo-localization Benchmark"
MIT License

Issue with Saving Checkpoint in Training Loop #26

Closed: SeunghanYu closed this issue 4 months ago

SeunghanYu commented 5 months ago

Hello, @gmberton!

I've encountered an issue when using the --resume option to load the best_model.pth file. Specifically, while the "model_state_dict" and "recalls" in the checkpoint file correctly store the weights and recall values of the best-performing epoch, the "epoch_num", "best_r5", and "not_improved_num" values reflect the training state from before that epoch was evaluated, i.e. the state at the end of the previous epoch.

For example, consider the following training log where epoch 77 achieves the highest R@5 of 79.1, after which the training stops:

2024-05-20 14:38:01   Start training epoch: 77
2024-05-20 14:38:01   Cache: 0 / 5
2024-05-20 14:40:31   Epoch[77](0/5): current batch triplet loss = 0.0058, average epoch triplet loss = 0.0059
2024-05-20 14:40:31   Cache: 1 / 5
2024-05-20 14:42:58   Epoch[77](1/5): current batch triplet loss = 0.0030, average epoch triplet loss = 0.0066
2024-05-20 14:42:58   Cache: 2 / 5
2024-05-20 14:45:25   Epoch[77](2/5): current batch triplet loss = 0.0096, average epoch triplet loss = 0.0065
2024-05-20 14:45:25   Cache: 3 / 5
2024-05-20 14:47:54   Epoch[77](3/5): current batch triplet loss = 0.0000, average epoch triplet loss = 0.0066
2024-05-20 14:47:54   Cache: 4 / 5
2024-05-20 14:50:23   Epoch[77](4/5): current batch triplet loss = 0.0039, average epoch triplet loss = 0.0065
2024-05-20 14:50:23   Finished epoch 77 in 0:12:22, average epoch triplet loss = 0.0065
2024-05-20 14:50:23   Extracting database features for evaluation/testing
2024-05-20 14:51:36   Extracting queries features for evaluation/testing
2024-05-20 14:52:17   Calculating recalls
2024-05-20 14:52:19   Recalls on val set < BaseDataset, msls - #database: 18871; #queries: 11084 >: R@1: 65.5, R@5: 79.1, R@10: 82.7, R@20: 86.0
2024-05-20 14:52:20   Improved: previous best R@5 = 78.4, current R@5 = 79.1

The best_model.pth file then contains:

epoch_num: <class 'int'>
  value: 77
model_state_dict: <class 'collections.OrderedDict'>
optimizer_state_dict: <class 'dict'>
recalls: <class 'numpy.ndarray'>
  value: [65.53590761 79.12306027 82.70479971 86.02490076]
best_r5: <class 'numpy.float64'>
  value: 78.44640923854205
not_improved_num: <class 'int'>
  value: 0
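For reference, a dump in this format can be produced with a few lines of Python once the checkpoint has been loaded into a dict (e.g. with `torch.load`). This is a sketch, not code from the repository or from this report: `describe_checkpoint` is a hypothetical helper, and it prints values only for scalars and NumPy arrays so that state dicts are summarized by type.

```python
from collections import OrderedDict

import numpy as np


def describe_checkpoint(checkpoint):
    """Return '<key>: <type>' lines, plus a '  value: ...' line for scalar and array entries.

    Hypothetical helper: state dicts and other containers are summarized by type only.
    """
    lines = []
    for key, value in checkpoint.items():
        lines.append(f"{key}: {type(value)}")
        # np.float64 subclasses float; np.generic covers the other NumPy scalar types
        if isinstance(value, (int, float, np.generic, np.ndarray)):
            lines.append(f"  value: {value}")
    return lines


if __name__ == "__main__":
    # Toy stand-in for a loaded best_model.pth, using values from the dump above
    example = {
        "epoch_num": 77,
        "model_state_dict": OrderedDict(),
        "best_r5": np.float64(78.44640923854205),
    }
    print("\n".join(describe_checkpoint(example)))
```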

When I resume training with this checkpoint, the log shows:

2024-05-21 01:40:17   Loaded checkpoint: start_epoch_num = 77, current_best_R@5 = 78.4
2024-05-21 01:40:17   Resuming from epoch 77 with best recall@5 78.4

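It appears that the checkpoint is saved before best_r5 and the related variables are updated. I think this can be fixed by updating these variables before saving the checkpoint, and by storing epoch_num + 1 so that a resumed run starts at the next epoch instead of repeating the saved one.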
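The ordering problem can be reproduced without any training by replaying R@5 values through the same bookkeeping. In this sketch (not repository code), `simulate` is a hypothetical function, and the returned dict stands in for what would be copied to best_model.pth when `is_best` is true; model and optimizer state are left out.

```python
def simulate(r5_per_epoch, update_before_save):
    """Replay R@5 values through the training-loop bookkeeping.

    Returns the state captured for the last improving epoch, i.e. what best_model.pth
    would hold. Hypothetical simulation: no model or optimizer state is involved.
    """
    best_r5, not_improved_num = 0.0, 0
    saved_best = None
    for epoch_num, r5 in enumerate(r5_per_epoch):
        is_best = r5 > best_r5
        if not update_before_save:
            # Current ordering: the state is captured before the bookkeeping below
            state = {"epoch_num": epoch_num, "r5": r5, "best_r5": best_r5,
                     "not_improved_num": not_improved_num}
        if is_best:
            best_r5, not_improved_num = r5, 0
        else:
            not_improved_num += 1
        if update_before_save:
            # Proposed ordering: bookkeeping first, then capture (storing epoch_num + 1)
            state = {"epoch_num": epoch_num + 1, "r5": r5, "best_r5": best_r5,
                     "not_improved_num": not_improved_num}
        if is_best:
            saved_best = state  # stands in for the copy to best_model.pth
    return saved_best


if __name__ == "__main__":
    print(simulate([78.4, 78.0, 79.1], update_before_save=False))
    # {'epoch_num': 2, 'r5': 79.1, 'best_r5': 78.4, 'not_improved_num': 1}
    print(simulate([78.4, 78.0, 79.1], update_before_save=True))
    # {'epoch_num': 3, 'r5': 79.1, 'best_r5': 79.1, 'not_improved_num': 0}
```

With the current ordering the saved best_r5 is the previous best and not_improved_num still carries the counter from the epoch before, which mirrors what the checkpoint in this report shows.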

Current train.py code:

...

is_best = recalls[1] > best_r5

# Save checkpoint, which contains all training parameters
util.save_checkpoint(args, {
    "epoch_num": epoch_num, "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(), "recalls": recalls, "best_r5": best_r5,
    "not_improved_num": not_improved_num
}, is_best, filename="last_model.pth")

# If recall@5 did not improve for "many" epochs, stop training
if is_best:
    logging.info(f"Improved: previous best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}")
    best_r5 = recalls[1]
    best_epoch = epoch_num
    not_improved_num = 0
else:
    not_improved_num += 1
    logging.info(f"Not improved: {not_improved_num} / {args.patience}: best R@5 = {best_r5:.1f} at epoch: {best_epoch:.1f}, current R@5 = {recalls[1]:.1f}")
    if not_improved_num >= args.patience:
        logging.info(f"Performance did not improve for {not_improved_num} epochs. Stop training.")
        break

Proposed change:

is_best = recalls[1] > best_r5

# If recall@5 did not improve for "many" epochs, stop training
if is_best:
    logging.info(f"Improved: previous best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}")
    best_r5 = recalls[1]
    best_epoch = epoch_num
    not_improved_num = 0
else:
    not_improved_num += 1
    logging.info(f"Not improved: {not_improved_num} / {args.patience}: best R@5 = {best_r5:.1f} at epoch: {best_epoch:.1f}, current R@5 = {recalls[1]:.1f}")
    if not_improved_num >= args.patience:
        logging.info(f"Performance did not improve for {not_improved_num} epochs. Stop training.")
        break

# Save checkpoint, which contains all training parameters
util.save_checkpoint(args, {
    "epoch_num": epoch_num+1, "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(), "recalls": recalls, "best_r5": best_r5,
    "not_improved_num": not_improved_num
}, is_best, filename="last_model.pth")
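Storing `epoch_num+1` and the placement of the patience check can both be checked without a GPU by replaying recall values through the proposed bookkeeping. This is a simulation sketch, not repository code: `train_loop` is a hypothetical name, appending a dict stands in for `util.save_checkpoint`, and resuming is assumed to take the start epoch straight from `checkpoint["epoch_num"]`, as the "Resuming from epoch 77" log line suggests. One deliberate deviation from the proposal: the patience check runs after the save. In the proposed ordering, `break` fires before `util.save_checkpoint`, so the epoch that exhausts the patience would not be written to last_model.pth.

```python
def train_loop(r5_per_epoch, patience, checkpoint=None):
    """Replay per-epoch R@5 values through the proposed bookkeeping.

    Hypothetical simulation: nothing is trained, and each dict appended to `saved`
    stands in for one write of last_model.pth.
    """
    start_epoch_num, best_r5, not_improved_num = 0, 0.0, 0
    if checkpoint is not None:
        # Assumed resume behaviour: restore the stored values unchanged
        start_epoch_num = checkpoint["epoch_num"]
        best_r5 = checkpoint["best_r5"]
        not_improved_num = checkpoint["not_improved_num"]

    saved = []
    for epoch_num in range(start_epoch_num, len(r5_per_epoch)):
        r5 = r5_per_epoch[epoch_num]
        if r5 > best_r5:
            best_r5, not_improved_num = r5, 0
        else:
            not_improved_num += 1
        # Bookkeeping is already updated; epoch_num + 1 is the next epoch to run
        saved.append({"epoch_num": epoch_num + 1, "best_r5": best_r5,
                      "not_improved_num": not_improved_num})
        # Checked after the save so the epoch that exhausts patience is still written
        if not_improved_num >= patience:
            break
    return saved


if __name__ == "__main__":
    r5 = [78.4, 79.1, 78.0, 78.2]
    first_run = train_loop(r5[:2], patience=3)  # interrupted after two epochs
    resumed = train_loop(r5, patience=3, checkpoint=first_run[-1])
    print(first_run[-1])  # {'epoch_num': 2, 'best_r5': 79.1, 'not_improved_num': 0}
    print(resumed[0])     # {'epoch_num': 3, 'best_r5': 79.1, 'not_improved_num': 1}
```

The resumed run starts at the epoch after the saved one instead of retraining it, and keeps the correct best R@5 and patience counter.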

Would this change be a proper solution to ensure that the best_model.pth file correctly reflects the best-performing epoch?

Looking forward to your reply. Thank you in advance!

gmberton commented 5 months ago

Hi, this looks good. Please run a couple experiments to make sure it runs as intended, then feel free to open a PR :)
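For such experiments, one quick check is that a best_model.pth written after the fix is internally consistent. The sketch below is an assumption-based check, not part of the repository: `best_checkpoint_problems` is a hypothetical name, and it only compares the fields shown in this issue. It expects an already-loaded dict, e.g. from `torch.load("best_model.pth", map_location="cpu")`; depending on the PyTorch version, `weights_only=False` may also be needed because the checkpoint stores NumPy values.

```python
import numpy as np


def best_checkpoint_problems(checkpoint):
    """Return a list of inconsistencies in a loaded best_model.pth dict (empty if none found).

    Hypothetical check: after the fix, the best checkpoint should store its own R@5
    as best_r5 and a patience counter reset to 0.
    """
    problems = []
    r5 = float(checkpoint["recalls"][1])
    best_r5 = float(checkpoint["best_r5"])
    if best_r5 != r5:
        problems.append(f"best_r5 = {best_r5:.1f}, but recalls give R@5 = {r5:.1f}")
    if checkpoint["not_improved_num"] != 0:
        problems.append(f"not_improved_num = {checkpoint['not_improved_num']}, expected 0")
    return problems


if __name__ == "__main__":
    # Values from the checkpoint reported in this issue (written before the fix)
    before_fix = {
        "recalls": np.array([65.53590761, 79.12306027, 82.70479971, 86.02490076]),
        "best_r5": np.float64(78.44640923854205),
        "not_improved_num": 0,
    }
    print(best_checkpoint_problems(before_fix))
    # ['best_r5 = 78.4, but recalls give R@5 = 79.1']
```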