Closed alfredcs closed 5 years ago
Try to modify this line (https://github.com/yxgeee/FD-GAN/blob/master/fdgan/model.py#L57) to:
state_dict['embed_model.classifier.bias'] = torch.FloatTensor([state_dict['embed_model.classifier.bias'][1]]).view(1,-1)
Ran into an error during pretrain. Please take a look when free. Thx!
Traceback (most recent call last): File "train.py", line 119, in <module>
main()
File "train.py", line 53, in main
model = FDGANModel(opt)
File "/FD-GAN/fdgan/model.py", line 28, in __init__
self._init_models()
File "/FD-GAN/fdgan/model.py", line 59, in _init_models
self.net_Di.load_state_dict(state_dict).view(-1,2048)
File "/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SiameseNet:
size mismatch for embed_model.classifier.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1, 2048]).
Option is PreTrain
def _init_models(self):
    """Create the four sub-networks and initialise or load their weights.

    Builds ``net_G`` (pose generator), ``net_E`` (Siamese encoder with a
    2-way classifier), ``net_Di`` (Siamese discriminator with a 1-way
    classifier) and ``net_Dp`` (patch discriminator), then:

    * stage 1: randomly initialises ``net_G`` / ``net_Dp`` and seeds both
      ``net_E`` and ``net_Di`` from the checkpoint at
      ``opt.netE_pretrain``;
    * stage 2: loads every network from its own ``opt.net*_pretrain``
      checkpoint.

    Raises:
        ValueError: if ``opt.stage`` is neither 1 nor 2.
    """
    self.net_G = CustomPoseGenerator(
        self.opt.pose_feature_size, 2048, self.opt.noise_feature_size,
        dropout=self.opt.drop, norm_layer=self.norm_layer,
        fuse_mode=self.opt.fuse_mode,
        connect_layers=self.opt.connect_layers)

    e_base_model = create(self.opt.arch, cut_at_pooling=True)
    e_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                    num_features=2048, num_classes=2)
    self.net_E = SiameseNet(e_base_model, e_embed_model)

    di_base_model = create(self.opt.arch, cut_at_pooling=True)
    di_embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                     num_features=2048, num_classes=1)
    self.net_Di = SiameseNet(di_base_model, di_embed_model)
    self.net_Dp = NLayerDiscriminator(3 + 18, norm_layer=self.norm_layer)

    if self.opt.stage == 1:
        init_weights(self.net_G)
        init_weights(self.net_Dp)
        # The checkpoint is expected to keep its weights under 'state_dict'.
        checkpoint = torch.load(self.opt.netE_pretrain)
        state_dict = remove_module_key(checkpoint['state_dict'])
        self.net_E.load_state_dict(state_dict)

        # net_Di has a single-output classifier, so reuse only row 1 of the
        # encoder's 2-way classifier (presumably the positive class --
        # confirm). Slice with [1:2] rather than index with [1]: indexing
        # drops the leading dimension and yields a [2048] weight, which
        # load_state_dict rejects because the model expects [1, 2048].
        # Overwriting the entries is safe here: net_E was loaded above.
        for key in ('embed_model.classifier.weight',
                    'embed_model.classifier.bias'):
            state_dict[key] = state_dict[key][1:2]
        self.net_Di.load_state_dict(state_dict)
    elif self.opt.stage == 2:
        self._load_state_dict(self.net_E, self.opt.netE_pretrain)
        self._load_state_dict(self.net_G, self.opt.netG_pretrain)
        self._load_state_dict(self.net_Di, self.opt.netDi_pretrain)
        self._load_state_dict(self.net_Dp, self.opt.netDp_pretrain)
    else:
        # The former `assert('unknown training stage')` asserted a non-empty
        # string and therefore could never fail; raise explicitly instead.
        raise ValueError(
            'unknown training stage: {!r}'.format(self.opt.stage))