hkchengrex / Mask-Propagation

[CVPR 2021] MiVOS - Mask Propagation module. Reproduced STM (and better) with training code :star2:. Semi-supervised video object segmentation evaluation.
https://hkchengrex.github.io/MiVOS/
MIT License
127 stars, 22 forks

Pre-training on the BL30K dataset after pre-training on static images #5

Closed. nero1342 closed this issue 3 years ago.

nero1342 commented 3 years ago

In the static-image pre-training stage, `single_object` in `PropagationNetwork` is True, so `MaskRGBEncoderSO` is used. When I load the weights from that stage for BL30K pre-training or main training, `single_object` is False and the model uses `MaskRGBEncoder` instead, so the checkpoint fails to load. Here is the error:

```
Traceback (most recent call last):
  File "train.py", line 68, in <module>
    total_iter = model.load_model(para['load_model'])
  File "/content/Mask-Propagation/model/model.py", line 180, in load_model
    self.PNet.module.load_state_dict(network)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for PropagationNetwork:
	size mismatch for mask_rgb_encoder.conv1.weight: copying a param with shape torch.Size([64, 4, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 5, 7, 7]).
```

Could you explain how to fix this? Thank you so much.
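A quick way to see exactly which tensors disagree before calling `load_state_dict` is to compare shapes key by key. The sketch below is not part of the repository: `find_shape_mismatches` and `ToyNet` are hypothetical names, and `ToyNet` only reproduces the first convolution of the encoder, with the 4-channel versus 5-channel shapes taken from the traceback in the report.

```python
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in holding only the encoder's first conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.mask_rgb_encoder = nn.Module()
        self.mask_rgb_encoder.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, bias=False)


def find_shape_mismatches(model: nn.Module, checkpoint: dict):
    """Return (key, checkpoint_shape, model_shape) for every tensor whose shape differs."""
    model_state = model.state_dict()
    mismatches = []
    for key, tensor in checkpoint.items():
        if key in model_state and model_state[key].shape != tensor.shape:
            mismatches.append((key, tuple(tensor.shape), tuple(model_state[key].shape)))
    return mismatches


# Single-object (4-channel) checkpoint vs. multi-object (5-channel) model.
print(find_shape_mismatches(ToyNet(5), ToyNet(4).state_dict()))
# [('mask_rgb_encoder.conv1.weight', (64, 4, 7, 7), (64, 5, 7, 7))]
```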

longmalongma commented 3 years ago

Did you run the static-image pre-training stage yourself? What changes did you make to the original code? When I run static-image pre-training, it gets stuck at one line and the server stops responding.

hkchengrex commented 3 years ago

@nero1342 You should use `load_network`; the network weight transfer is handled by this line.
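The repository's own transfer code is not reproduced here. As a minimal sketch of the idea, assume the multi-object encoder differs from the single-object one only by one extra mask input channel in `conv1` (as the 4 vs. 5 channel shapes suggest), and that the new slice of the weight can be zero-initialized; the actual `load_network` may initialize it differently. `widen_input_channels` and `ToyNet` are hypothetical names.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in holding only the encoder's first conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.mask_rgb_encoder = nn.Module()
        self.mask_rgb_encoder.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, bias=False)


def widen_input_channels(state_dict: dict, key: str, target_in: int) -> dict:
    """Pad a conv weight (out, in, kh, kw) with zero input channels up to target_in."""
    weight = state_dict[key]
    out_c, in_c, kh, kw = weight.shape
    if in_c >= target_in:
        return state_dict  # nothing to widen
    # Assumption: zero init, so the new channel contributes nothing at first.
    pad = torch.zeros(out_c, target_in - in_c, kh, kw, dtype=weight.dtype, device=weight.device)
    widened = dict(state_dict)  # shallow copy; the caller's dict is left untouched
    widened[key] = torch.cat([weight, pad], dim=1)
    return widened


key = "mask_rgb_encoder.conv1.weight"
single_object_ckpt = ToyNet(4).state_dict()
multi_object_net = ToyNet(5)

# Loading directly fails with a size-mismatch RuntimeError, as in the report.
try:
    multi_object_net.load_state_dict(single_object_ckpt)
except RuntimeError as err:
    print("direct load failed:", type(err).__name__)

multi_object_net.load_state_dict(widen_input_channels(single_object_ckpt, key, 5))
print(multi_object_net.mask_rgb_encoder.conv1.weight.shape)  # torch.Size([64, 5, 7, 7])
```

Zero-initializing keeps the transferred network's initial output identical to the single-object model's for the original four inputs; training then learns how to use the extra channel.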

`load_model` resumes interrupted training (i.e., it also restores the optimizer, scheduler, and iteration count), while `load_network` loads only the network weights.
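To illustrate the difference between resuming and loading weights only, here is a generic PyTorch sketch. The checkpoint keys (`it`, `network`, `optimizer`, `scheduler`) and the helper names are assumptions for illustration, not the repository's actual file format or functions.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, it, network, optimizer, scheduler):
    # Hypothetical layout: weights plus everything needed to resume training.
    torch.save({
        "it": it,
        "network": network.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, path)


def resume_training(path, network, optimizer, scheduler):
    """Analogous in spirit to load_model: restores weights, optimizer, scheduler, iteration."""
    ckpt = torch.load(path, map_location="cpu")
    network.load_state_dict(ckpt["network"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["it"]


def load_weights_only(path, network):
    """Analogous in spirit to load_network: weights only, training starts fresh."""
    ckpt = torch.load(path, map_location="cpu")
    network.load_state_dict(ckpt["network"])


net = nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    save_checkpoint(path, 1234, net, opt, sched)

    # Moving to the next training stage: weights only.
    fresh_net = nn.Linear(4, 2)
    load_weights_only(path, fresh_net)

    # Continuing an interrupted run: restore everything.
    resumed_net = nn.Linear(4, 2)
    resumed_opt = torch.optim.Adam(resumed_net.parameters(), lr=1e-4)
    resumed_sched = torch.optim.lr_scheduler.StepLR(resumed_opt, step_size=10)
    resumed_it = resume_training(path, resumed_net, resumed_opt, resumed_sched)

print(resumed_it)  # 1234
```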

Sorry for the confusion, but I recall that we did use load_network in the sample training commands in the readme.

hkchengrex commented 3 years ago

@longmalongma Just to clarify -- supposedly no changes are required. I just tried and it started running within a minute.

nero1342 commented 3 years ago

@hkchengrex Thank you so much, I understood and have already fixed it.

nero1342 commented 3 years ago

@longmalongma You can try adding tqdm to this line, for example: `for data in tqdm(train_loader, total=len(train_loader)):`. Then you can see the progress of each epoch.
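A self-contained version of that suggestion is sketched below. The toy `TensorDataset` stands in for the real training set and is purely hypothetical; the point is that wrapping the `DataLoader` in `tqdm` gives a per-epoch progress bar, which makes it easy to tell a slow loader apart from a hang.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Hypothetical toy dataset standing in for the real training data.
dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

num_epochs = 2
seen_batches = 0
for epoch in range(num_epochs):
    # len(train_loader) is the number of batches per epoch (here ceil(100 / 8) = 13),
    # which lets tqdm show a percentage and an ETA.
    for (batch,) in tqdm(train_loader, total=len(train_loader), desc=f"epoch {epoch}"):
        seen_batches += 1  # the training step would go here
```

Since `DataLoader` defines `__len__`, tqdm would infer `total` on its own; passing it explicitly matters mainly for iterables that have no length.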

longmalongma commented 3 years ago

Thanks for your reply, I will try it. Do you have WeChat? Could you share your WeChat ID?

hkchengrex commented 3 years ago

@longmalongma It would be best to keep the discussion on GitHub so that everyone who runs into the same problem can benefit from it.

longmalongma commented 3 years ago

OK, thanks.