WongKinYiu / yolor

implementation of paper - You Only Learn One Representation: Unified Network for Multiple Tasks (https://arxiv.org/abs/2105.04206)
GNU General Public License v3.0
1.98k stars 524 forks

Fine-tuning #195

Closed ilkergalipatak closed 2 years ago

ilkergalipatak commented 2 years ago

Hello. I trained a model with yolor, and now I want to add one more class to it. I think I need to fine-tune the model on a new dataset. How do I fine-tune a yolor model? Can you help with this?

LoveDH commented 2 years ago

Just run python3 train.py --weight {previous weights}.pt ... If you have pretrained weights such as some_model/weights/best.pt, the code below in train.py will pick them up.

    # Model
    pretrained = weights.endswith(".pt")
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Darknet(opt.cfg).to(device)  # create
        state_dict = {
            k: v
            for k, v in ckpt["model"].items()
            if model.state_dict()[k].numel() == v.numel()
        }
        model.load_state_dict(state_dict, strict=False)
        print(
            "Transferred %g/%g items from %s"
            % (len(state_dict), len(model.state_dict()), weights)
        )  # report
    else:
        model = Darknet(opt.cfg).to(device)  # create
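
For the add-one-class case asked about above, it may help to see what the numel() filter in that snippet does. The following is only a minimal, torch-free sketch: shape tuples stand in for tensors, the layer names and shapes are made up, head_filters assumes the usual YOLO-style head size of (classes + 5) * anchors, and the `k in model_shapes` guard is an addition that the quoted train.py does not have.

```python
from math import prod


def head_filters(num_classes, anchors_per_scale=3):
    # Assumed YOLO-style head width: (x, y, w, h, objectness + classes) per anchor.
    return (num_classes + 5) * anchors_per_scale


def transferable_keys(ckpt_shapes, model_shapes):
    # Keep a checkpoint entry only if the new model has the same key
    # and the same number of elements (mirrors the numel() comparison).
    return [
        k
        for k, shape in ckpt_shapes.items()
        if k in model_shapes and prod(model_shapes[k]) == prod(shape)
    ]


old_nc, new_nc = 2, 3  # hypothetical: previous model had 2 classes, new one has 3

# Hypothetical layer names and shapes, for illustration only.
ckpt_shapes = {
    "backbone.conv.weight": (64, 32, 3, 3),
    "head.conv.weight": (head_filters(old_nc), 64, 1, 1),
    "head.conv.bias": (head_filters(old_nc),),
}
model_shapes = {
    "backbone.conv.weight": (64, 32, 3, 3),
    "head.conv.weight": (head_filters(new_nc), 64, 1, 1),
    "head.conv.bias": (head_filters(new_nc),),
}

kept = transferable_keys(ckpt_shapes, model_shapes)
print("Transferred %g/%g items" % (len(kept), len(model_shapes)))
```

Under these assumptions only the backbone entry is copied. The head entries change size with the class count, so they are skipped and keep their fresh initialization. In practice this suggests the cfg passed as opt.cfg would need to declare the new class count before training, so that the "Transferred" count comes out lower than the total.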

ilkergalipatak commented 2 years ago

Thank you for the answer.