princeton-vl / CornerNet


Train error on my dataset: dimensions in the model are torch.Size([1]) but dimensions in the checkpoint are torch.Size([80]) #157

Open gengxiaomeng opened 4 years ago

gengxiaomeng commented 4 years ago

Dear all:

When I train on my own data the category count is 1, so my code is modified as follows:

1. Modification one, models/CornerNet.py, line 72:

        class model(kp):
            def __init__(self, db):
                n       = 5
                dims    = [256, 256, 384, 384, 384, 512]
                modules = [2, 2, 2, 2, 2, 4]
                out_dim = 1   # <-- changed to my category count

2. Modification two, db/detection.py, line 8:

        self._configs["categories"] = 1   # <-- changed to my category count

3. Modification three, config/CornerNet.json, line 45 (changed to my category count):

        "categories": 1

4. The error is reported as follows:

        loading all datasets...
        using 4 threads
        loading from cache file: ./cache/coco_trainval2017.pkl
        loading annotations into memory...
        Done (t=0.13s)
        creating index...
        index created!
        loading from cache file: ./cache/coco_trainval2017.pkl
        loading annotations into memory...
        Done (t=0.15s)
        creating index...
        index created!
        loading from cache file: ./cache/coco_trainval2017.pkl
        loading annotations into memory...
        Done (t=0.17s)
        creating index...
        index created!
        loading from cache file: ./cache/coco_trainval2017.pkl
        loading annotations into memory...
        Done (t=0.18s)
        creating index...
        index created!
        loading from cache file: ./cache/coco_test2017.pkl
        loading annotations into memory...
        Done (t=0.12s)
        creating index...
        index created!
        system config...
        {'batch_size': 4, 'cache_dir': './cache', 'chunk_sizes': [4], 'config_dir': './config', 'data_dir': './data', 'data_rng': <mtrand.RandomState object at 0x7f28e24bd708>, 'dataset': 'MSCOCO', 'decay_rate': 10, 'display': 5, 'learning_rate': 0.00025, 'max_iter': 500000, 'nnet_rng': <mtrand.RandomState object at 0x7f28e24bd750>, 'opt_algo': 'adam', 'prefetch_size': 5, 'pretrain': './cache/nnet/CornerNet/CornerNet_500000.pkl', 'result_dir': './results', 'sampling_function': 'kp_detection', 'snapshot': 5000, 'snapshot_name': 'CornerNet', 'stepsize': 450000, 'test_split': 'testdev', 'train_split': 'trainval', 'val_iter': 100, 'val_split': 'minival', 'weight_decay': False, 'weight_decay_rate': 1e-05, 'weight_decay_type': 'l2'}
        db config...
        {'ae_threshold': 0.5, 'border': 128, 'categories': 1, 'data_aug': True, 'gaussian_bump': True, 'gaussian_iou': 0.3, 'gaussian_radius': -1, 'input_size': [511, 511], 'lighting': True, 'max_per_image': 100, 'merge_bbox': False, 'nms_algorithm': 'exp_soft_nms', 'nms_kernel': 3, 'nms_threshold': 0.5, 'output_sizes': [[128, 128]], 'rand_color': True, 'rand_crop': True, 'rand_pushes': False, 'rand_samples': False, 'rand_scale_max': 1.4, 'rand_scale_min': 0.6, 'rand_scale_step': 0.1, 'rand_scales': array([0.6, 0.7, 0.8, 0.9, 1. , 1.1, 1.2, 1.3]), 'special_crop': False, 'test_scales': [1], 'top_k': 100, 'weight_exp': 8}
        len of db: 10615
        start prefetching data...
        start prefetching data...
        start prefetching data...
        shuffling indices...
        start prefetching data...
        start prefetching data...
        shuffling indices...
        shuffling indices...
        shuffling indices...
        building model...
        module_file: models.CornerNet
        shuffling indices...
        total parameters: 200954000
        loading from pretrained model
        loading from ./cache/nnet/CornerNet/CornerNet_500000.pkl
        Traceback (most recent call last):
          File "/home/hhd/deeplearning/detection/github/CornerNet/train.py", line 195, in <module>
            train(training_dbs, validation_db, args.start_iter)
          File "/home/hhd/deeplearning/detection/github/CornerNet/train.py", line 120, in train
            nnet.load_pretrained_params(pretrained_model)
          File "/home/hhd/deeplearning/detection/github/CornerNet/nnet/py_factory.py", line 110, in load_pretrained_params
            self.model.load_state_dict(params)
          File "/home/gxm/anaconda3/envs/CornerNet/lib/python3.6/site-packages/torch/nn/modules/module.py", line 721, in load_state_dict
            self.__class__.__name__, "\n\t".join(error_msgs)))
        RuntimeError: Error(s) in loading state_dict for DummyModule:
            While copying the parameter named "module.tl_heats.0.1.weight", whose dimensions in the model are torch.Size([1, 256, 1, 1]) and whose dimensions in the checkpoint are torch.Size([80, 256, 1, 1]).
            While copying the parameter named "module.tl_heats.0.1.bias", whose dimensions in the model are torch.Size([1]) and whose dimensions in the checkpoint are torch.Size([80]).
            While copying the parameter named "module.tl_heats.1.1.weight", whose dimensions in the model are torch.Size([1, 256, 1, 1]) and whose dimensions in the checkpoint are torch.Size([80, 256, 1, 1]).
            While copying the parameter named "module.tl_heats.1.1.bias", whose dimensions in the model are torch.Size([1]) and whose dimensions in the checkpoint are torch.Size([80]).
            While copying the parameter named "module.br_heats.0.1.weight", whose dimensions in the model are torch.Size([1, 256, 1, 1]) and whose dimensions in the checkpoint are torch.Size([80, 256, 1, 1]).
            While copying the parameter named "module.br_heats.0.1.bias", whose dimensions in the model are torch.Size([1]) and whose dimensions in the checkpoint are torch.Size([80]).
            While copying the parameter named "module.br_heats.1.1.weight", whose dimensions in the model are torch.Size([1, 256, 1, 1]) and whose dimensions in the checkpoint are torch.Size([80, 256, 1, 1]).
            While copying the parameter named "module.br_heats.1.1.bias", whose dimensions in the model are torch.Size([1]) and whose dimensions in the checkpoint are torch.Size([80]).
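
The mismatch is expected with these settings: the pretrained CornerNet_500000.pkl was trained on COCO with 80 categories, so its tl_heats and br_heats corner-heatmap heads have 80 output channels, while the model rebuilt with out_dim = 1 has only one. A minimal sketch to confirm which checkpoint tensors clash, assuming the .pkl is a plain state_dict saved with torch.save (the path is taken from the log above):

    import torch

    # Inspect the pretrained COCO checkpoint and print the corner-heatmap
    # head tensors; these are the parameters listed in the RuntimeError.
    ckpt_path = "./cache/nnet/CornerNet/CornerNet_500000.pkl"
    params = torch.load(ckpt_path, map_location="cpu")

    for name, tensor in params.items():
        if "tl_heats" in name or "br_heats" in name:
            # COCO checkpoint: torch.Size([80, 256, 1, 1]) / torch.Size([80]);
            # the 1-category model expects 1 in the first dimension instead.
            print(name, tuple(tensor.shape))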
niovl commented 4 years ago

In py_factory.py, "def load_pretrained_params(self, pretrained_model)" should be modified like this:

def load_pretrained_params(self, pretrained_model):
    print("loading from {}".format(pretrained_model))
    with open(pretrained_model, "rb") as f:
        params     = torch.load(f)
        model_dict = self.model.state_dict()
        # keep only the checkpoint tensors whose name and shape match the
        # current model; the 80-class tl_heats/br_heats heads are dropped here
        params = {k: v for k, v in params.items()
                  if k in model_dict and model_dict[k].shape == v.shape}
        # strict=False: parameters missing from the filtered dict keep their
        # random initialization instead of raising an error
        self.model.load_state_dict(params, strict=False)
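
With strict=False, only the filtered parameters are copied: the tl_heats and br_heats heads, whose shapes no longer match the 80-class checkpoint, are skipped and keep their random initialization, so they are trained from scratch for the single category while the hourglass backbone still starts from the COCO weights. As a hypothetical check (the skipped variable and the print are mine, not part of the repo), a couple of lines can be added right after the dict comprehension to see which parameters fall back to random init:

    # hypothetical debugging lines, placed inside load_pretrained_params
    # directly after the params dict comprehension above
    skipped = [k for k in self.model.state_dict() if k not in params]
    print("loaded {} tensors from checkpoint, {} left at random init".format(
        len(params), len(skipped)))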