InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

How can I finetune the high bitrate model (λ=0.05) to obtain a low bitrate model (i.e. λ=0.013)? #280

Closed wonlee2019 closed 5 months ago

wonlee2019 commented 5 months ago

To save computational resources, I think a training strategy of fine-tuning a well-trained high bitrate model to obtain a low bitrate model is necessary.

Problem: when I run training with "--lambda 0.0013 --load-checkpoint /../0.05checkpoint_best.pth.tar", the training process doesn't start at epoch 0. It just exits without any error.

How can I implement this training strategy? Thanks a lot!

YodaEmbedding commented 5 months ago

Why this happens

The epoch is saved inside the checkpoint, along with the optimizer, LR scheduler, and other state:

            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
            )
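The likely consequence in your case: the epoch restored from the λ=0.05 checkpoint is already at (or past) the configured number of epochs, so the resumed epoch loop has an empty range and the script finishes immediately, which looks like exiting "without any error". A minimal standalone sketch of that behaviour (the variable names below are illustrative, not taken from the training script):

# Illustrative sketch (hypothetical names): why a resumed run can end instantly.
# If the checkpoint's stored epoch already reached the target epoch count,
# the resumed loop iterates over an empty range and the program exits cleanly.
restored_epoch = 99      # e.g. checkpoint["epoch"] from a fully trained model
total_epochs = 100       # e.g. the --epochs setting

last_epoch = restored_epoch + 1
for epoch in range(last_epoch, total_epochs):   # range(100, 100) is empty
    print(f"training epoch {epoch}")            # never executed

print("done")  # reached immediately; no error is raised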

A solution

When loading the checkpoint, you may want to skip loading some of this state. For example, modify the training code to:

def parse_args(argv):
    ...
    parser.add_argument("--start-fresh", action="store_true", help="Reset training state.")
    ...

def main(argv):
    ...

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        net.load_state_dict(checkpoint["state_dict"])
        if not args.start_fresh:  # restore training progress only when resuming
            last_epoch = checkpoint["epoch"] + 1
            optimizer.load_state_dict(checkpoint["optimizer"])
            aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    ...

Then, just run your training command with an additional --start-fresh flag.
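For example, a command along these lines (the dataset path and checkpoint name are placeholders, and the other flags are the usual examples/train.py options, so adjust them to your setup):

python examples/train.py -d /path/to/dataset --lambda 0.0013 \
    --checkpoint 0.05checkpoint_best.pth.tar --start-fresh --cuda --save

With --start-fresh, only the model weights are restored, while the epoch counter, optimizers, and LR scheduler start from scratch, so training begins again at epoch 0 with the new λ.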

wonlee2019 commented 5 months ago

Thanks a lot! I think I know where the problem lies. Thank you Ulhaq!