DonaldRR / SimpleNet


Issue in the test function #57

Open taurou opened 7 months ago

taurou commented 7 months ago

I think there's a mismatch between how the model weights are saved and how they are loaded in this code.

In the test function, the loading part is:

    def test(self, training_data, test_data, save_segmentation_images):
        ckpt_path = os.path.join(self.ckpt_dir, "models.ckpt")
        if os.path.exists(ckpt_path):
            state_dicts = torch.load(ckpt_path, map_location=self.device)
            if "pretrained_enc" in state_dicts:
                self.feature_enc.load_state_dict(state_dicts["pretrained_enc"])
            if "pretrained_dec" in state_dicts:
                self.feature_dec.load_state_dict(state_dicts["pretrained_dec"])

While in the train function, the loading part is:

    state_dict = {}
    ckpt_path = os.path.join(self.ckpt_dir, "ckpt.pth")
    if os.path.exists(ckpt_path):
        state_dict = torch.load(ckpt_path, map_location=self.device)
        if 'discriminator' in state_dict:
            self.discriminator.load_state_dict(state_dict['discriminator'])
            if "pre_projection" in state_dict:
                self.pre_projection.load_state_dict(state_dict["pre_projection"])
        else:
            self.load_state_dict(state_dict, strict=False)
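
Note that both loaders guard every load_state_dict call with an `in` check, so the key mismatch fails silently: nothing is restored and no error is raised. A minimal standalone sketch to illustrate the failure mode (the file name and keys are just placeholders):

    import torch

    # Save a checkpoint under one key...
    torch.save({"discriminator": {"w": torch.zeros(1)}}, "/tmp/ckpt.pth")

    # ...then probe it with the key the test function expects.
    state_dicts = torch.load("/tmp/ckpt.pth", map_location="cpu")
    if "pretrained_enc" in state_dicts:
        print("would load pretrained_enc")  # never reached
    else:
        print("key missing -- the weights are silently not loaded")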

And this is called when the best model is saved:

    def update_state_dict(d):
        state_dict["discriminator"] = OrderedDict({
            k: v.detach().cpu()
            for k, v in self.discriminator.state_dict().items()})
        if self.pre_proj > 0:
            state_dict["pre_projection"] = OrderedDict({
                k: v.detach().cpu()
                for k, v in self.pre_projection.state_dict().items()})

    # ... later in train(), when a new best record is found:
    update_state_dict(state_dict)
    torch.save(state_dict, ckpt_path)
    return best_record
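
To confirm what actually ends up on disk, you can list the keys of the saved file (the directory below is a placeholder, use your own ckpt_dir):

    import os
    import torch

    ckpt_dir = "path/to/your/ckpt_dir"  # placeholder: wherever train() wrote its checkpoint
    state_dict = torch.load(os.path.join(ckpt_dir, "ckpt.pth"), map_location="cpu")
    print(list(state_dict.keys()))  # per the save code above: ['discriminator', 'pre_projection']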

So there's a mismatch in the checkpoint path ("models.ckpt" vs. "ckpt.pth"), and the dictionary keys that test checks for ("pretrained_enc", "pretrained_dec") don't correspond to the keys that train saves ("discriminator", "pre_projection").
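
For what it's worth, here is a sketch of how I'd expect the loading block in test to look so that it matches what train saves (same attribute names as in the snippets above, untested):

    ckpt_path = os.path.join(self.ckpt_dir, "ckpt.pth")  # the file train() actually writes
    if os.path.exists(ckpt_path):
        state_dict = torch.load(ckpt_path, map_location=self.device)
        if "discriminator" in state_dict:
            self.discriminator.load_state_dict(state_dict["discriminator"])
        if "pre_projection" in state_dict:
            self.pre_projection.load_state_dict(state_dict["pre_projection"])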