lufficc / SSD

High quality, fast, modular reference implementation of SSD in PyTorch
MIT License
1.51k stars 385 forks source link

Fine Tuning a Trained Model #205

Open GaetanoPrudente opened 3 years ago

GaetanoPrudente commented 3 years ago

Is it possible to fine-tune a trained model on a different dataset (e.g. augmented data)? There is no documentation about this, or about the options needed to do it.

ksv87 commented 1 week ago

In train.py, add the following function, taken from the YOLOX repo (https://github.com/Megvii-BaseDetection/YOLOX/blob/f00a798c8bf59f43ab557a2f3d566afa831c8887/yolox/utils/checkpoint.py#L11):

def load_ckpt(model, ckpt, logger):
    """Copy every shape-compatible tensor from ``ckpt`` into ``model``.

    A model entry is loaded only when its key is present in ``ckpt`` and the
    stored tensor has exactly the same shape as the model's current one.
    Each missing or mismatched entry is left untouched and reported through
    ``logger.warning``. Keys that exist only in ``ckpt`` are skipped silently,
    because only the model's own keys are iterated.

    Args:
        model: the ``torch.nn.Module`` to update in place.
        ckpt: mapping from state-dict key to tensor, e.g. the ``"model"``
            entry of a loaded checkpoint.
        logger: any object with a ``warning(str)`` method, e.g. a
            ``logging.Logger``.

    Returns:
        The same ``model`` object, so the call can be chained or reassigned.
    """
    compatible = {}
    for name, current in model.state_dict().items():
        if name in ckpt:
            saved = ckpt[name]
            if saved.shape == current.shape:
                compatible[name] = saved
            else:
                # Typical for a head whose number of classes differs from the
                # checkpoint's: keep the model's freshly initialized tensor.
                logger.warning(
                    f"Shape of {name} in checkpoint is {saved.shape}, "
                    f"while shape of {name} in model is {current.shape}."
                )
        else:
            logger.warning(
                f"{name} is not in the ckpt. "
                "Please double check and see if this is desired."
            )

    # strict=False: entries filtered out above are simply absent from
    # `compatible`, and must not be reported as errors.
    model.load_state_dict(compatible, strict=False)
    return model

Then add the following to https://github.com/lufficc/SSD/blob/68dc0a20efaf3997e58b616afaaaa21bf8ca3c05/train.py#L26-L29:

    # Load fine-tuning weights into the *unwrapped* model, before the DDP wrap.
    # DistributedDataParallel prefixes every state_dict key with "module.".
    # Loading after wrapping would look up "module.<name>" in a checkpoint
    # saved from a plain model and match nothing. Training would then silently
    # start from scratch, with only "not in the ckpt" warnings.
    # NOTE(review): assumes the .pth was saved from an unwrapped model (keys
    # without a "module." prefix) -- confirm for the checkpoint being used.
    if args.finetuning_file is not None:
        # map_location="cpu" stops every rank from deserializing onto the GPU
        # the checkpoint was saved on. load_state_dict then copies each tensor
        # onto the device of the corresponding model parameter.
        ftc = torch.load(args.finetuning_file, map_location="cpu")["model"]
        model = load_ckpt(model, ftc, logger)
        logger.info(f"Loaded for finetuning {args.finetuning_file}")

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank
        )

    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = make_optimizer(cfg, model, lr)

And add the new command-line option to the argument parser in https://github.com/lufficc/SSD/blob/68dc0a20efaf3997e58b616afaaaa21bf8ca3c05/train.py#L50-L57:

    # YAML experiment configuration to train with.
    parser.add_argument(
        "--config-file",
        type=str,
        metavar="FILE",
        default="",
        help="path to config file",
    )
    # Optional checkpoint used to initialize the model for fine-tuning.
    # argparse exposes it as `args.finetuning_file` (dashes become
    # underscores). It stays None when the flag is omitted.
    parser.add_argument(
        "--finetuning-file",
        type=str,
        metavar="FILE",
        default=None,
        help="path to model for finetuning",
    )
    # Per-process rank; presumably filled in by the distributed launcher --
    # verify against how training is launched.
    parser.add_argument("--local_rank", default=0, type=int)

Then run: `python train.py --config-file configs\efficient_net_b3_ssd300_graff.yaml --finetuning-file efficient_net_b3_ssd300_voc0712.pth`