The epoch is saved inside the checkpoint, along with the LR scheduler state, and other state:
save_checkpoint(
    {
        "epoch": epoch,
        "state_dict": net.state_dict(),
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "aux_optimizer": aux_optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
    },
    is_best,
)
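For instance, you can check which state a saved checkpoint actually contains before deciding what to restore. A minimal sketch; the filename here is just a placeholder:

import torch

# Hypothetical example: inspect the contents of a saved checkpoint
checkpoint = torch.load("checkpoint_best.pth.tar", map_location="cpu")
print(list(checkpoint.keys()))
# e.g. ['epoch', 'state_dict', 'loss', 'optimizer', 'aux_optimizer', 'lr_scheduler']
print("Saved epoch:", checkpoint["epoch"])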
When loading the checkpoint, you may want to skip loading some of this state. For example, modify the training code to:
def parse_args(argv):
    ...
    parser.add_argument("--start-fresh", action="store_true", help="Reset training state.")
    ...

def main(argv):
    ...
    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        net.load_state_dict(checkpoint["state_dict"])
        if not args.start_fresh:
            last_epoch = checkpoint["epoch"] + 1
            optimizer.load_state_dict(checkpoint["optimizer"])
            aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    ...
Then, just run your training command with an additional --start-fresh flag.
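For example, assuming the flag names from the snippet above (the dataset path, script name, and -d flag are illustrative; adjust them to your actual training script):

python train.py -d /path/to/dataset --lambda 0.0013 --checkpoint /path/to/0.05checkpoint_best.pth.tar --start-fresh

With --start-fresh, only the model weights are restored, so training starts again from epoch 0 with fresh optimizer and LR scheduler state.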
Thanks a lot! I think I know where the problem lies. Thank you, Ulhaq!
To save computational resources, I think a training strategy that starts from a well-trained high-bitrate model and fine-tunes it to obtain a low-bitrate model is necessary.
Problem: when I run training with "--lambda 0.0013 --load-checkpoint /../0.05checkpoint_best.pth.tar", the training process doesn't start from epoch 0. It just exits without any error.
How can I implement this training strategy? Thanks a lot!