dk-liang / FIDTM

[IEEE TMM] Focal Inverse Distance Transform Maps for Crowd Localization

Testing issue #46

Open sulaimanvesal opened 8 months ago

sulaimanvesal commented 8 months ago

Hi,

Thank you for sharing the code. I tried a quick test after following all the data preparations. However, the output results are strange, especially the counts, e.g. IMG_1.jpg Gt 172.00 Pred 180049, as you can see below.

Am I missing something?

P.S.: I am testing the model on CPU.

Best,

(pytorch_env) D:\Project1\FIDTM>python test.py --dataset ShanghaiA --pre ./model/ShanghaiA/model_best.pth --gpu_id 0
{'dataset': 'ShanghaiA', 'save_path': 'save_file/A_baseline', 'workers': 16, 'print_freq': 200, 'start_epoch': 0, 'epochs': 3000, 'pre': './model/ShanghaiA/model_best.pth', 'batch_size': 16, 'crop_size': 256, 'seed': 1, 'best_pred': 100000.0, 'gpu_id': '0', 'lr': 0.0001, 'weight_decay': 0.0005, 'preload_data': True, 'visual': False, 'video_path': None}
Using cpu
./model/ShanghaiA/model_best.pth
=> loading checkpoint './model/ShanghaiA/model_best.pth'
57.0989010989011 921
Pre_load dataset ......
begin test
IMG_1.jpg Gt 172.00 Pred 180049
IMG_10.jpg Gt 502.00 Pred 196417
IMG_100.jpg Gt 391.00 Pred 92455
IMG_101.jpg Gt 211.00 Pred 184704
IMG_102.jpg Gt 223.00 Pred 31672
IMG_103.jpg Gt 430.00 Pred 170330


mariosconsta commented 8 months ago

There's a problem with your weight loading. The authors use strict=False when loading their weights, which means that if the checkpoint keys don't match the model, no error is thrown; the mismatched weights are silently skipped. Change that to True and you will see what the problem is.
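A minimal sketch of that check, assuming the same `args` and `model` variables as in `test.py`:

    import torch

    checkpoint = torch.load(args['pre'])
    # strict=True makes load_state_dict raise a RuntimeError that lists the
    # missing and unexpected keys instead of silently skipping them.
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    # Or keep strict=False and inspect the returned named tuple:
    # result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    # print(result.missing_keys, result.unexpected_keys)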

fyw1999 commented 5 months ago

That's because the keys in the pre-trained model's state dict are different from the keys of the model defined in the code (the checkpoint keys carry a "module." prefix from nn.DataParallel). First set strict = True to confirm the mismatch, and then I solved it in this manner:

    print("=> loading checkpoint '{}'".format(args['pre']))
    checkpoint = torch.load(args['pre'])
    pre_state_dict = checkpoint['state_dict']
    new_pre_state_dict = OrderedDict()
    for key in model.state_dict().keys():
        if "module."+key in pre_state_dict.keys():
        new_pre_state_dict[key] = pre_state_dict["module."+key]

    model.load_state_dict(new_pre_state_dict, strict=True)
    args['start_epoch'] = checkpoint['epoch']
    args['best_pred'] = checkpoint['best_prec1']
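For what it's worth, an equivalent way to build the dict is to strip the "module." prefix directly (a sketch, assuming every checkpoint key carries the prefix, which holds for DataParallel checkpoints):

    new_pre_state_dict = OrderedDict(
        (key.replace("module.", "", 1), value)
        for key, value in pre_state_dict.items()
    )

Once the keys match, the strict=True load succeeds and the predicted counts should come back in a sensible range.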