thuyngch / Human-Segmentation-PyTorch

Human segmentation models, training/inference code, and trained weights, implemented in PyTorch
558 stars 114 forks

Can't load the trained model #34

Open upperblacksmith opened 3 years ago

upperblacksmith commented 3 years ago

When I load the trained model UNet_ResNet18.pth with torch.load(DeepLabV3Plus_ResNet18.pth), an error occurs. The details are as follows: (error screenshot). I would appreciate any advice you can provide. Thanks a lot.

jzx-gooner commented 3 years ago

Save the model like this and you can load it without hitting the ModuleNotFoundError. Alternatively, load your existing model once and re-save it like this:

        # Construct savedict (original version, disabled)
        # arch = type(self.model).__name__
        # state = {
        #     'arch': arch,
        #     'epoch': epoch,
        #     'logger': self.train_logger,
        #     'state_dict': self.model.state_dict(),
        #     'optimizer': self.optimizer.state_dict(),
        #     'monitor_best': self.monitor_best,
        #     'config': self.config
        # }
        # Save only the model weights instead of the full dict above
        state_dict = self.model.state_dict()
        # Save a checkpoint every save_freq epochs
        if self.save_freq is not None:  # None disables saving, to avoid filling the disk with large models
            if epoch % self.save_freq == 0:
                filename = os.path.join(self.checkpoint_dir, 'epoch{}.pth'.format(epoch))
                torch.save(state_dict, filename)
                self.logger.info("Saving checkpoint at {}".format(filename))
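
For the "load and re-save" route, something like the following might work. It is only a rough sketch, assuming the old checkpoint is the full dict with a 'state_dict' key (as in the commented-out code above) and that you run it from a place where the original project modules still import, for example the repository root. The function names `extract_state_dict`, `resave_weights_only`, and `load_weights` are made up for illustration and are not part of the repository.

```python
def extract_state_dict(checkpoint):
    """Return the weights from a full checkpoint dict.

    If the object has no 'state_dict' key, it is assumed to already be
    a bare state_dict and is returned unchanged.
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


def resave_weights_only(src_path, dst_path):
    """Load an old full checkpoint once and write out only its weights."""
    import torch  # imported lazily so extract_state_dict works without PyTorch

    # Unpickling the full dict still needs the original project modules here.
    # On newer PyTorch releases where torch.load defaults to weights_only=True,
    # a trusted full checkpoint may need weights_only=False passed explicitly.
    checkpoint = torch.load(src_path, map_location="cpu")
    torch.save(extract_state_dict(checkpoint), dst_path)


def load_weights(model, path):
    """Load a weights-only file into an already constructed model."""
    import torch

    state_dict = extract_state_dict(torch.load(path, map_location="cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # switch to inference mode
    return model
```

After re-saving once, the new file holds only tensors keyed by parameter name, so `load_weights` can be called from any script as long as you build the same model architecture first.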