fredzzhang / upt

[CVPR'22] Official PyTorch implementation for paper "Efficient Two-Stage Detection of Human–Object Interactions with a Novel Unary–Pairwise Transformer"
https://fredzzhang.com/unary-pairwise-transformers
BSD 3-Clause "New" or "Revised" License

list out of range and checkpoint's state_dict mismatch #78

Closed Dawn-LX closed 1 year ago

Dawn-LX commented 1 year ago

Hello,

1) When I try the randomly initialised model, it runs into "list index out of range" at this line: target_cls_idx = [self.object_class_to_target_class[obj.item()]

2) I then tried to use the pre-trained UPT checkpoint, but the state_dict does not match.

My argument configuration is correct; the default value of args.dataset is hicodet.

So I wonder whether this is a bug or a problem on my side?

For 1):

python main.py --eval --backbone resnet101 --dilation --resume /path/to/model --data-root /storage/gaokaifeng

Namespace(alpha=0.5, aux_loss=True, backbone='resnet101', batch_size=2, bbox_loss_coef=5, box_score_thresh=0.2, cache=False, clip_max_norm=0.1, data_root='/storage/gaokaifeng', dataset='hicodet', dec_layers=6, device='cuda', dilation=True, dim_feedforward=2048, dropout=0.1, enc_layers=6, eos_coef=0.1, epochs=20, eval=True, fg_iou_thresh=0.5, gamma=0.2, giou_loss_coef=2, hidden_dim=256, lr_backbone=1e-05, lr_drop=10, lr_head=0.0001, max_instances=15, min_instances=3, nheads=8, num_queries=100, num_workers=2, output_dir='checkpoints', partitions=['train2015', 'test2015'], port='1234', position_embedding='sine', pre_norm=False, pretrained='', print_interval=500, repr_dim=512, resume='/path/to/model', sanity=False, seed=66, set_cost_bbox=5, set_cost_class=1, set_cost_giou=2, weight_decay=0.0001, world_size=1)
=> Rank 0: start from a randomly initialised model
  8%|████████████          | 766/9546 [01:29<17:09,  8.53it/s]
Traceback (most recent call last):
  File "main.py", line 210, in <module>
    mp.spawn(main, nprocs=args.world_size, args=(args,))
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/gaokaifeng/project/upt/main.py", line 99, in main
    ap = engine.test_hico(test_loader)
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/gaokaifeng/project/upt/utils.py", line 169, in test_hico
    output = net(inputs)
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/gaokaifeng/project/upt/upt.py", line 252, in forward
    logits, prior, bh, bo, objects, attn_maps = self.interaction_head(
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/gaokaifeng/project/upt/interaction_head.py", line 366, in forward
    prior_collated.append(self.compute_prior_scores(
  File "/home/gaokaifeng/project/upt/interaction_head.py", line 260, in compute_prior_scores
    target_cls_idx = [self.object_class_to_target_class[obj.item()]
  File "/home/gaokaifeng/project/upt/interaction_head.py", line 260, in <listcomp>
    target_cls_idx = [self.object_class_to_target_class[obj.item()]
IndexError: list index out of range

For 2):

(torch111) gaokaifeng@server1:~/project/upt$ python main.py \
>         --data-root /storage/gaokaifeng \
>         --eval \
>         --backbone resnet101 \
>         --dilation \
>         --resume checkpoints/upt-r101-dc5-hicodet.pt
Namespace(alpha=0.5, aux_loss=True, backbone='resnet101', batch_size=2, bbox_loss_coef=5, box_score_thresh=0.2, cache=False, clip_max_norm=0.1, data_root='/storage/gaokaifeng', dataset='hicodet', dec_layers=6, device='cuda', dilation=True, dim_feedforward=2048, dropout=0.1, enc_layers=6, eos_coef=0.1, epochs=20, eval=True, fg_iou_thresh=0.5, gamma=0.2, giou_loss_coef=2, hidden_dim=256, lr_backbone=1e-05, lr_drop=10, lr_head=0.0001, max_instances=15, min_instances=3, nheads=8, num_queries=100, num_workers=2, output_dir='checkpoints', partitions=['train2015', 'test2015'], port='1234', position_embedding='sine', pre_norm=False, pretrained='', print_interval=500, repr_dim=512, resume='checkpoints/upt-r101-dc5-hicodet.pt', sanity=False, seed=66, set_cost_bbox=5, set_cost_class=1, set_cost_giou=2, weight_decay=0.0001, world_size=1)
=> Rank 0: continue from saved checkpoint checkpoints/upt-r101-dc5-hicodet.pt
Traceback (most recent call last):
  File "main.py", line 210, in <module>
    mp.spawn(main, nprocs=args.world_size, args=(args,))
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/gaokaifeng/project/upt/main.py", line 76, in main
    upt.load_state_dict(checkpoint['model_state_dict'])
  File "/home/gaokaifeng/anaconda3/envs/torch111/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UPT:
        size mismatch for detector.class_embed.weight: copying a param with shape torch.Size([81, 256]) from checkpoint, the shape in current model is torch.Size([92, 256]).
        size mismatch for detector.class_embed.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([92]).
fredzzhang commented 1 year ago

Hi @Dawn-LX,

This seems to result from the same issue as #77. If the classifier weights have 92 classes instead of 81, it suggests you are using the original DETR implementation, not the custom one attached as a submodule.
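To make the 92-versus-81 diagnosis concrete, here is a minimal sketch that compares parameter shapes between a checkpoint and a freshly built model. The helper name find_shape_mismatches is hypothetical and not part of the UPT code base; the shapes in the example are copied from the error message in the report.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape],
    model_shapes: Dict[str, Shape],
) -> List[Tuple[str, Shape, Shape]]:
    """Hypothetical helper: list parameters whose shapes differ.

    Returns (name, checkpoint_shape, model_shape) for every parameter
    present in both mappings with a different shape.
    """
    mismatches = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and model_shape != ckpt_shape:
            mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


# Shapes taken from the RuntimeError in the report.
checkpoint_shapes = {
    "detector.class_embed.weight": (81, 256),
    "detector.class_embed.bias": (81,),
}
model_shapes = {
    "detector.class_embed.weight": (92, 256),
    "detector.class_embed.bias": (92,),
}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for name, ckpt_shape, model_shape in mismatches:
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With PyTorch, the two mappings could be built from real state dicts, for example {k: tuple(v.shape) for k, v in checkpoint['model_state_dict'].items()} and the same expression over the model's state_dict().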

Simply navigate to the detr directory and check out (i.e. discard) all local changes until git status shows no more modifications.
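One possible way to verify that the submodule is clean, sketched under the assumptions that git is on the PATH and that the submodule lives in a directory named detr as mentioned above. Both function names are hypothetical; the script only reports pending changes and does not discard anything.

```python
import subprocess
from typing import List


def parse_porcelain(output: str) -> List[str]:
    """Return the paths listed in `git status --porcelain` (v1) output."""
    paths = []
    for line in output.splitlines():
        if len(line) > 3:
            # Each entry is a two-character status code, a space, then the path.
            paths.append(line[3:])
    return paths


def pending_changes(repo_dir: str) -> List[str]:
    """Hypothetical helper: list modified or untracked paths in repo_dir."""
    result = subprocess.run(
        ["git", "status", "--porcelain"],
        cwd=repo_dir,
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_porcelain(result.stdout)
```

Calling pending_changes("detr") from the UPT root should return an empty list once the submodule matches the committed version; any remaining paths are files that still differ.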

Fred.

Dawn-LX commented 1 year ago

Thank you for your quick reply! I will discuss this further in issue #77.