ylabbe / cosypose

Code for "CosyPose: Consistent multi-view multi-object 6D pose estimation", ECCV 2020.
MIT License

Training Cosypose on One single object #47

Open salimkhazem opened 3 years ago

salimkhazem commented 3 years ago

Hello, I want to train CosyPose on a single object. How can I do that? What are the steps?

Thanks a lot

hannes56a commented 3 years ago

Hi, I want to train on a single object too. I'm currently trying to train only the detector, but I get an error when I load the pretrained checkpoint in "train_detector.py" in this part:

if args.run_id_pretrain is not None:
    pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
    logger.info(f'Using pretrained model from {pretrain_path}.')
    model.load_state_dict(torch.load(pretrain_path)['state_dict'])

The error is

Error(s) in loading state_dict for DetectorMaskRCNN: size mismatch for roi_heads.box_predictor.cls_score.weight: copying a param with shape torch.Size([31, 1024]) from checkpoint, the shape in current model is torch.Size([2, 1024]).

I know it's because I use the pretrained checkpoint from T-LESS with 30 classes (+ background), but isn't it possible to reduce the number of classes to 1 (+ background)?

As a workaround, I trained with 30 identical classes ;-)

hannes56a commented 3 years ago

@azad96 Perhaps you have an idea?

azad96 commented 3 years ago

@hannes56a, you can load that checkpoint and inspect its last layers. Then you can remove the layer with 30 classes and add your own instead. I've done a similar thing for the pose network, shown below.

ckpt = torch.load(pretrain_path)['state_dict']
# Freshly initialized layer with the same shape as pose_fc (1536 -> 9).
dummy_weight = torch.nn.Linear(1536, 9, bias=True)
dummy_weight.cuda()
dummy_state_dict = dummy_weight.state_dict()
# Overwrite the pretrained pose_fc parameters with the random ones.
ckpt["pose_fc.weight"] = dummy_state_dict['weight']
ckpt["pose_fc.bias"] = dummy_state_dict['bias']
model.load_state_dict(ckpt)

What I did was re-randomize the last fully connected layer of the pretrained T-LESS pose model, because I wanted to get rid of the high-level features.
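The same idea can be sketched without building a separate dummy layer, by taking the freshly initialized parameters from the new model itself. This is only a minimal illustration: `TinyPoseNet` is a hypothetical stand-in for the real pose network, and the in-memory `ckpt` plays the role of `torch.load(pretrain_path)['state_dict']`.

```python
import torch
from torch import nn


class TinyPoseNet(nn.Module):
    """Hypothetical stand-in for the pose network (not the CosyPose model)."""

    def __init__(self, n_out=9):
        super().__init__()
        self.backbone = nn.Linear(16, 32)
        self.pose_fc = nn.Linear(32, n_out)

    def forward(self, x):
        return self.pose_fc(torch.relu(self.backbone(x)))


torch.manual_seed(0)
pretrained = TinyPoseNet()
# Stand-in for torch.load(pretrain_path)['state_dict'].
ckpt = dict(pretrained.state_dict())

model = TinyPoseNet()
fresh = model.state_dict()

# Keep the pretrained backbone, but overwrite the last layer's entries
# with the new model's own random initialization.
for key in ("pose_fc.weight", "pose_fc.bias"):
    ckpt[key] = fresh[key].clone()

model.load_state_dict(ckpt)
```

After loading, the backbone carries the pretrained weights while `pose_fc` keeps its random initialization, so it is trained from scratch.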

hannes56a commented 3 years ago

@azad96 Thank you so much. That helped me solve the problem. Now I can train the detector with one class :-)

Here is the code snippet with my changes in train_detector.py:

if args.run_id_pretrain is not None:
    pretrain_path = EXP_DIR / args.run_id_pretrain / 'checkpoint.pth.tar'
    logger.info(f'Using pretrained model from {pretrain_path}.')
    ckpt = torch.load(pretrain_path)['state_dict']

    # Print the checkpoint keys and the model to find the class-dependent layers.
    print(ckpt.keys())
    print(model)

    # Classification head: 1 object class + background = 2 outputs.
    dummy_weight_cls_score = torch.nn.Linear(1024, 2, bias=True)
    dummy_weight_cls_score.cuda()
    dummy_state_dict_cls_score = dummy_weight_cls_score.state_dict()
    ckpt["roi_heads.box_predictor.cls_score.weight"] = dummy_state_dict_cls_score['weight']
    ckpt["roi_heads.box_predictor.cls_score.bias"] = dummy_state_dict_cls_score['bias']

    # Box regression head: 4 box coordinates per class, 2 classes = 8 outputs.
    dummy_weight_bbox_pred = torch.nn.Linear(1024, 8, bias=True)
    dummy_weight_bbox_pred.cuda()
    dummy_state_dict_bbox_pred = dummy_weight_bbox_pred.state_dict()
    ckpt["roi_heads.box_predictor.bbox_pred.weight"] = dummy_state_dict_bbox_pred['weight']
    ckpt["roi_heads.box_predictor.bbox_pred.bias"] = dummy_state_dict_bbox_pred['bias']

    # Mask head: one output channel per class (2 channels).
    dummy_weight_mask = torch.nn.Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
    dummy_weight_mask.cuda()
    dummy_state_dict_mask = dummy_weight_mask.state_dict()
    ckpt["roi_heads.mask_predictor.mask_fcn_logits.weight"] = dummy_state_dict_mask['weight']
    ckpt["roi_heads.mask_predictor.mask_fcn_logits.bias"] = dummy_state_dict_mask['bias']

    # All shapes now match the single-class model, so the load succeeds.
    model.load_state_dict(ckpt)
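
The layers replaced here are exactly the ones whose shape depends on the number of classes. A more generic variant, sketched below under assumptions, skips every checkpoint entry whose shape differs from the current model and loads the rest with `strict=False`, so the skipped layers keep their random initialization. `load_matching_weights` and `TinyDetector` are hypothetical names used for illustration only; they are not part of CosyPose or torchvision.

```python
import torch
from torch import nn


def load_matching_weights(model, ckpt_state):
    """Load only the entries whose name and shape match the model.

    Returns the list of checkpoint keys that were skipped.
    """
    model_state = model.state_dict()
    kept, skipped = {}, []
    for key, value in ckpt_state.items():
        if key in model_state and model_state[key].shape == value.shape:
            kept[key] = value
        else:
            skipped.append(key)
    # strict=False tolerates the keys that are missing from `kept`.
    model.load_state_dict(kept, strict=False)
    return skipped


class TinyDetector(nn.Module):
    """Hypothetical stand-in with class-dependent heads (not Mask R-CNN)."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(8, 1024)
        self.cls_score = nn.Linear(1024, num_classes)
        self.bbox_pred = nn.Linear(1024, num_classes * 4)


pretrained = TinyDetector(num_classes=31)  # 30 classes + background
model = TinyDetector(num_classes=2)        # 1 class + background
skipped = load_matching_weights(model, pretrained.state_dict())
print(skipped)
```

Printing the skipped keys is a quick check that only the class-dependent heads were left out and nothing in the backbone was dropped by accident.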