Linardos / SalEMA

Simple vs complex temporal recurrences for video saliency prediction (BMVC 2019)
https://imatge-upc.github.io/SalEMA/

Some bugs in train.py #8

Closed: Li-Jinquan closed this issue 4 years ago

Li-Jinquan commented 4 years ago

Hi, in your train.py, line 163:

```python
if args.pt_model != False:
    # In truth it's not None, we default to SalGAN or SalBCE (JuanJo's) weights

    # By setting strict to False we allow the model to load only the matching layers' weights
    if SALGAN_WEIGHTS == 'model_weights/gen_model.pt':
        model.salgan.load_state_dict(torch.load(SALGAN_WEIGHTS), strict=False)
    else:
        model.salgan.load_state_dict(torch.load(SALGAN_WEIGHTS)['state_dict'], strict=False)
    start_epoch = 1
else:
    # Load an entire pretrained model
    checkpoint = load_weights(model, args.pt_model)
    model.load_state_dict(checkpoint, strict=False)
    start_epoch = torch.load(args.pt_model, map_location='cpu')['epoch']
    #optimizer.load_state_dict(torch.load(args.pt_model, map_location='cpu')['optimizer'])

    print("Model loaded, commencing training from epoch {}".format(start_epoch))
```

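The type of `args.pt_model` is `str`, and a string never compares equal to the boolean `False`, so `args.pt_model != False` is always true. No matter whether we set `pt_model` to "False" or "True", it will always use the SalGAN or SalBCE weights, right?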
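A minimal sketch of the comparison being described, with no torch dependency. It assumes the flag is declared with `argparse` as `type=str, default=None`; the thread does not show how train.py actually declares `--pt_model`, and the path `checkpoints/my_model.pt` is hypothetical. The helper `takes_salgan_branch` is also hypothetical and only mirrors the quoted `if` condition.

```python
import argparse

# Assumed flag declaration (not shown in the thread): a string path or None.
parser = argparse.ArgumentParser()
parser.add_argument("--pt_model", type=str, default=None)

def takes_salgan_branch(pt_model):
    # Hypothetical helper reproducing the quoted condition, bug included.
    # Neither None nor any str is equal to False, so this is always True.
    return pt_model != False  # noqa: E712

cases = [
    [],                                     # flag omitted -> None
    ["--pt_model", "False"],                # the string "False", not the bool
    ["--pt_model", "True"],
    ["--pt_model", "checkpoints/my_model.pt"],  # hypothetical checkpoint path
]
results = [takes_salgan_branch(parser.parse_args(argv).pt_model) for argv in cases]
print(results)  # [True, True, True, True]
```

Every case takes the SalGAN/SalBCE branch, so the `else` branch that loads a full pretrained model is unreachable.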

Linardos commented 4 years ago

Hi Li, sorry for taking this long to get back to you. `args.pt_model` doesn't have a fixed type: it should be set either to a string with the path to your pretrained model, or to None if you're not loading one, in which case the code defaults to the SalGAN/SalBCE weights. But you are right, there is a bug there: the condition on line 155 should be `if args.pt_model == None:`. I don't know when that changed or why; thanks for pointing it out.
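A sketch of the control flow with the corrected check, again without torch. The helper `weight_source` and its return strings are hypothetical stand-ins for the two loading branches, and the `argparse` declaration (`type=str, default=None`) is an assumption. It uses `is None`, the PEP 8 spelling, which behaves the same as the suggested `== None` for `str`/`None` values.

```python
import argparse
from typing import Optional

def weight_source(pt_model: Optional[str]) -> str:
    # Hypothetical helper: reports which branch the corrected condition takes.
    if pt_model is None:
        # No checkpoint given: fall back to SalGAN/SalBCE weights (start_epoch = 1).
        return "salgan"
    # A path was given: load the entire pretrained model from it.
    return "checkpoint:" + pt_model

# Assumed flag declaration, as in the thread's description.
parser = argparse.ArgumentParser()
parser.add_argument("--pt_model", type=str, default=None)

print(weight_source(parser.parse_args([]).pt_model))
print(weight_source(parser.parse_args(["--pt_model", "checkpoints/my_model.pt"]).pt_model))
```

With this check, omitting the flag is the only way to get the SalGAN/SalBCE defaults; passing `--pt_model False` would now be treated as a path to a file literally named `False`.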