I think there's a mismatch between how the code saves the model weights and how it loads them.
In the `test` function, the weights are loaded like this:
```python
def test(self, training_data, test_data, save_segmentation_images):
    ckpt_path = os.path.join(self.ckpt_dir, "models.ckpt")  # <- filename used by test()
    if os.path.exists(ckpt_path):
        state_dicts = torch.load(ckpt_path, map_location=self.device)
        # <- keys "pretrained_enc" / "pretrained_dec", loaded into feature_enc / feature_dec
        if "pretrained_enc" in state_dicts:
            self.feature_enc.load_state_dict(state_dicts["pretrained_enc"])
        if "pretrained_dec" in state_dicts:
            self.feature_dec.load_state_dict(state_dicts["pretrained_dec"])
```
The `train` function, on the other hand, loads them like this:
```python
state_dict = {}
ckpt_path = os.path.join(self.ckpt_dir, "ckpt.pth")  # <- filename used by train()
if os.path.exists(ckpt_path):
    state_dict = torch.load(ckpt_path, map_location=self.device)
    # <- keys "discriminator" / "pre_projection", loaded into the matching modules
    if 'discriminator' in state_dict:
        self.discriminator.load_state_dict(state_dict['discriminator'])
        if "pre_projection" in state_dict:
            self.pre_projection.load_state_dict(state_dict["pre_projection"])
    else:
        self.load_state_dict(state_dict, strict=False)
```
And this is what runs when the best model is saved:
```python
def update_state_dict(d):
    state_dict["discriminator"] = OrderedDict({
        k: v.detach().cpu()
        for k, v in self.discriminator.state_dict().items()})
    if self.pre_proj > 0:
        state_dict["pre_projection"] = OrderedDict({
            k: v.detach().cpu()
            for k, v in self.pre_projection.state_dict().items()})

update_state_dict(state_dict)
torch.save(state_dict, ckpt_path)
return best_record
```
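To make the disagreement concrete, here is a small stand-alone check that compares what `train` writes with what `test` reads. The filenames and keys are copied from the snippets above; the function `compare_checkpoint_io` is a hypothetical helper of mine, not part of the repo.

```python
def compare_checkpoint_io(saved_file, saved_keys, loaded_file, expected_keys):
    """Report how a checkpoint writer and a checkpoint reader disagree."""
    saved, expected = set(saved_keys), set(expected_keys)
    return {
        "same_file": saved_file == loaded_file,
        # The reader looks for these keys, but the writer never stores them.
        "keys_never_found": sorted(expected - saved),
        # The writer stores these keys, but the reader never loads them.
        "keys_never_loaded": sorted(saved - expected),
    }


# Values taken from the train() and test() snippets in this issue.
report = compare_checkpoint_io(
    saved_file="ckpt.pth",
    saved_keys=["discriminator", "pre_projection"],
    loaded_file="models.ckpt",
    expected_keys=["pretrained_enc", "pretrained_dec"],
)
print(report)
# {'same_file': False, 'keys_never_found': ['pretrained_dec', 'pretrained_enc'], 'keys_never_loaded': ['discriminator', 'pre_projection']}
```

Because `test` wraps the load in `if os.path.exists(ckpt_path)` and checks each key with `if ... in state_dicts`, neither disagreement raises an error: the load is simply skipped.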
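So the paths don't match (`test` reads `models.ckpt`, while `train` saves to and loads from `ckpt.pth`), and the dictionary entries in `test` also don't seem to correspond: it looks for `pretrained_enc` and `pretrained_dec`, but the checkpoint saved in `train` only contains `discriminator` and, when `pre_proj > 0`, `pre_projection`.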
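One possible way to avoid this, sketched under my own assumptions, is to have `train` and `test` share one filename and one set of keys. The names `CKPT_NAME`, `pack_state_dict` and `unpack_state_dict` are hypothetical, not from the repo. The helpers only rely on `state_dict()` / `load_state_dict()`, so they should work with `torch.nn.Module` objects, and the dict returned by `pack_state_dict` is what would be passed to `torch.save`.

```python
from collections import OrderedDict

# Hypothetical shared filename: train() and test() would both build their path from it.
CKPT_NAME = "ckpt.pth"


def _detach_cpu(value):
    """Apply value.detach().cpu() to tensor-like values; return anything else unchanged."""
    if hasattr(value, "detach") and hasattr(value, "cpu"):
        return value.detach().cpu()
    return value


def pack_state_dict(modules):
    """Build the checkpoint dict with one entry per named module,
    e.g. {"discriminator": {...}, "pre_projection": {...}}."""
    return {
        name: OrderedDict((k, _detach_cpu(v)) for k, v in module.state_dict().items())
        for name, module in modules.items()
    }


def unpack_state_dict(state_dict, modules):
    """Load each entry into the module registered under the same key.

    Returns the keys that were expected but absent, so a mismatch is
    reported instead of being skipped silently.
    """
    missing = []
    for name, module in modules.items():
        if name in state_dict:
            module.load_state_dict(state_dict[name])
        else:
            missing.append(name)
    return missing
```

In `train`, saving would then be roughly `torch.save(pack_state_dict(modules), os.path.join(self.ckpt_dir, CKPT_NAME))`, and `test` would call `unpack_state_dict(torch.load(path, map_location=self.device), modules)` with the same `modules` mapping, e.g. `{"discriminator": self.discriminator}` plus `"pre_projection": self.pre_projection` when `self.pre_proj > 0`. A non-empty return value makes a mismatch like the one above visible.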