Closed WangZhi-wz closed 1 year ago
在训练过程中,tran.yaml关于预训练模型描述是这样的: model: model_name : mit_PLD_b4 is_pretrained : False pretrained_path : /mnt/DATA-1/DATA-2/Feilong/scformer/models/mit/mit_b4.pth from_epoch : 0
train.py加载预训练模型是这样的: if config['model']['is_pretrained']: model.load_state_dict(torch.load(config['model']['pretrained_path'])) logger.info("successfully add pretrained model")
但是我又发现模型文件mit_PLD-b2.py最后好像有加载预训练模型: def _init_weights(self): pretrained_dict = torch.load('models/mit/mit_b2.pth') model_dict = self.backbone.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) self.backbone.load_state_dict(model_dict) print("successfully loaded!!!!")
所有想问是否需要在train.yaml中设置加载预训练模型的参数,因为当我设置加载预训练模型后代码会报错: RuntimeError: Error(s) in loading state_dict for mit_PLD_b2: Missing key(s) in state_dict: "total_ops", "total_params"...................
Both are correct.
在训练过程中,tran.yaml关于预训练模型描述是这样的: model: model_name : mit_PLD_b4 is_pretrained : False pretrained_path : /mnt/DATA-1/DATA-2/Feilong/scformer/models/mit/mit_b4.pth from_epoch : 0
train.py加载预训练模型是这样的: if config['model']['is_pretrained']: model.load_state_dict(torch.load(config['model']['pretrained_path'])) logger.info("successfully add pretrained model")
但是我又发现模型文件mit_PLD-b2.py最后好像有加载预训练模型: def _init_weights(self): pretrained_dict = torch.load('models/mit/mit_b2.pth') model_dict = self.backbone.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) self.backbone.load_state_dict(model_dict) print("successfully loaded!!!!")
所有想问是否需要在train.yaml中设置加载预训练模型的参数,因为当我设置加载预训练模型后代码会报错: RuntimeError: Error(s) in loading state_dict for mit_PLD_b2: Missing key(s) in state_dict: "total_ops", "total_params"...................
你好,我也出现这个报错,请问你是如何解决的?谢谢
在训练过程中,tran.yaml关于预训练模型描述是这样的: model: model_name : mit_PLD_b4 is_pretrained : False pretrained_path : /mnt/DATA-1/DATA-2/Feilong/scformer/models/mit/mit_b4.pth from_epoch : 0
train.py加载预训练模型是这样的: if config['model']['is_pretrained']: model.load_state_dict(torch.load(config['model']['pretrained_path'])) logger.info("successfully add pretrained model")
但是我又发现模型文件mit_PLD-b2.py最后好像有加载预训练模型: def _init_weights(self): pretrained_dict = torch.load('models/mit/mit_b2.pth') model_dict = self.backbone.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) self.backbone.load_state_dict(model_dict) print("successfully loaded!!!!")
所有想问是否需要在train.yaml中设置加载预训练模型的参数,因为当我设置加载预训练模型后代码会报错: RuntimeError: Error(s) in loading state_dict for mit_PLD_b2: Missing key(s) in state_dict: "total_ops", "total_params"...................