qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License

Can't load saved model after training #898

Open meowmoewrainbow opened 1 month ago

meowmoewrainbow commented 1 month ago

Hi, I followed this notebook to train a new model and saved it as a PyTorch Lightning checkpoint. After training, I tried to load the checkpoint, but it fails.

The code:

class PolypModel(pl.LightningModule):

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # preprocessing parameters for image normalization
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # for image segmentation, Dice loss is a good first choice
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, image):
        # normalize image here
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):

        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32:
        # the encoder and decoder are connected by skip connections, and the encoder
        # usually has 5 stages of downsampling by a factor of 2 (2 ** 5 = 32).
        # E.g. for a 65x65 image the encoder features would be roughly 33, 17, 9, 5, 3,
        # while the decoder upsamples 3 -> 6 -> 12 -> 24 -> 48, so the skip-connection
        # features have mismatched shapes and we get an error trying to concat them.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values are between 0 and 1, NOT 0 and 255, for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Let's compute metrics for some threshold:
        # first convert logits to probabilities, then
        # apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute the IoU metric in two ways:
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class;
        # these values will be aggregated at the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per-image IoU: first calculate the IoU score for each image
        # and then compute the mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # dataset IoU: aggregate intersection and union over the whole dataset
        # and then compute the IoU score. In this particular case the difference between
        # dataset_iou and per_image_iou will be small; however, for a dataset with
        # "empty" images (images without the target class) a large gap can be observed.
        # Empty images strongly affect per_image_iou and have much less effect on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }

        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")            

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")  

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)

checkpoint = torch.load('/datadrive/thaonp47/WeakPolyp/deeplabv3plus-eff/epoch=04-valid_per_image_iou=0.64.ckpt')
model.load_state_dict(checkpoint)

And the errors: Traceback (most recent call last): File "test.py", line 205, in <module> model.load_state_dict(checkpoint) File "/home/azureuser/anaconda3/envs/thaonp47/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for PolypModel: Missing key(s) in state_dict: "std", "mean", "model.encoder._conv_stem.weight", "model.encoder._bn0.weight", "model.encoder._bn0.bias", "model.encoder._bn0.running_mean", "model.encoder._bn0.running_var", "model.encoder._blocks.0._depthwise_conv.weight", "model.encoder._blocks.0._bn1.weight", "model.encoder._blocks.0._bn1.bias", "model.encoder._blocks.0._bn1.running_mean", "model.encoder._blocks.0._bn1.running_var", "model.encoder._blocks.0._se_reduce.weight", "model.encoder._blocks.0._se_reduce.bias", "model.encoder._blocks.0._se_expand.weight", "model.encoder._blocks.0._se_expand.bias", "model.encoder._blocks.0._project_conv.weight", "model.encoder._blocks.0._bn2.weight", "model.encoder._blocks.0._bn2.bias", "model.encoder._blocks.0._bn2.running_mean", "model.encoder._blocks.0._bn2.running_var", "model.encoder._blocks.1._expand_conv.weight", "model.encoder._blocks.1._bn0.weight", "model.encoder._blocks.1._bn0.bias", "model.encoder._blocks.1._bn0.running_mean", "model.encoder._blocks.1._bn0.running_var", "model.encoder._blocks.1._depthwise_conv.weight", "model.encoder._blocks.1._bn1.weight", "model.encoder._blocks.1._bn1.bias", "model.encoder._blocks.1._bn1.running_mean", "model.encoder._blocks.1._bn1.running_var", "model.encoder._blocks.1._se_reduce.weight", "model.encoder._blocks.1._se_reduce.bias", "model.encoder._blocks.1._se_expand.weight", "model.encoder._blocks.1._se_expand.bias", "model.encoder._blocks.1._project_conv.weight", "model.encoder._blocks.1._bn2.weight", "model.encoder._blocks.1._bn2.bias", "model.encoder._blocks.1._bn2.running_mean", "model.encoder._blocks.1._bn2.running_var", "model.encoder._blocks.2._expand_conv.weight", "model.encoder._blocks.2._bn0.weight", "model.encoder._blocks.2._bn0.bias", "model.encoder._blocks.2._bn0.running_mean", "model.encoder._blocks.2._bn0.running_var", "model.encoder._blocks.2._depthwise_conv.weight", "model.encoder._blocks.2._bn1.weight", "model.encoder._blocks.2._bn1.bias", "model.encoder._blocks.2._bn1.running_mean", "model.encoder._blocks.2._bn1.running_var", "model.encoder._blocks.2._se_reduce.weight", "model.encoder._blocks.2._se_reduce.bias", "model.encoder._blocks.2._se_expand.weight", "model.encoder._blocks.2._se_expand.bias", "model.encoder._blocks.2._project_conv.weight", "model.encoder._blocks.2._bn2.weight", "model.encoder._blocks.2._bn2.bias", "model.encoder._blocks.2._bn2.running_mean", "model.encoder._blocks.2._bn2.running_var", "model.encoder._blocks.3._expand_conv.weight", "model.encoder._blocks.3._bn0.weight", "model.encoder._blocks.3._bn0.bias", "model.encoder._blocks.3._bn0.running_mean", "model.encoder._blocks.3._bn0.running_var", "model.encoder._blocks.3._depthwise_conv.weight", "model.encoder._blocks.3._bn1.weight", "model.encoder._blocks.3._bn1.bias", "model.encoder._blocks.3._bn1.running_mean", "model.encoder._blocks.3._bn1.running_var", "model.encoder._blocks.3._se_reduce.weight", "model.encoder._blocks.3._se_reduce.bias", "model.encoder._blocks.3._se_expand.weight", "model.encoder._blocks.3._se_expand.bias", "model.encoder._blocks.3._project_conv.weight", "model.encoder._blocks.3._bn2.weight", 
"model.encoder._blocks.3._bn2.bias", "model.encoder._blocks.3._bn2.running_mean", "model.encoder._blocks.3._bn2.running_var", "model.encoder._blocks.4._expand_conv.weight", "model.encoder._blocks.4._bn0.weight", "model.encoder._blocks.4._bn0.bias", "model.encoder._blocks.4._bn0.running_mean", "model.encoder._blocks.4._bn0.running_var", "model.encoder._blocks.4._depthwise_conv.weight", "model.encoder._blocks.4._bn1.weight", "model.encoder._blocks.4._bn1.bias", "model.encoder._blocks.4._bn1.running_mean", "model.encoder._blocks.4._bn1.running_var", "model.encoder._blocks.4._se_reduce.weight", "model.encoder._blocks.4._se_reduce.bias", "model.encoder._blocks.4._se_expand.weight", "model.encoder._blocks.4._se_expand.bias", "model.encoder._blocks.4._project_conv.weight", "model.encoder._blocks.4._bn2.weight", "model.encoder._blocks.4._bn2.bias", "model.encoder._blocks.4._bn2.running_mean", "model.encoder._blocks.4._bn2.running_var", "model.encoder._blocks.5._expand_conv.weight", "model.encoder._blocks.5._bn0.weight", "model.encoder._blocks.5._bn0.bias", "model.encoder._blocks.5._bn0.running_mean", "model.encoder._blocks.5._bn0.running_var", "model.encoder._blocks.5._depthwise_conv.weight", "model.encoder._blocks.5._bn1.weight", "model.encoder._blocks.5._bn1.bias", "model.encoder._blocks.5._bn1.running_mean", "model.encoder._blocks.5._bn1.running_var", "model.encoder._blocks.5._se_reduce.weight", "model.encoder._blocks.5._se_reduce.bias", "model.encoder._blocks.5._se_expand.weight", "model.encoder._blocks.5._se_expand.bias", "model.encoder._blocks.5._project_conv.weight", "model.encoder._blocks.5._bn2.weight", "model.encoder._blocks.5._bn2.bias", "model.encoder._blocks.5._bn2.running_mean", "model.encoder._blocks.5._bn2.running_var", "model.encoder._blocks.6._expand_conv.weight", "model.encoder._blocks.6._bn0.weight", "model.encoder._blocks.6._bn0.bias", "model.encoder._blocks.6._bn0.running_mean", "model.encoder._blocks.6._bn0.running_var", "model.encoder._blocks.6._depthwise_conv.weight", "model.encoder._blocks.6._bn1.weight", "model.encoder._blocks.6._bn1.bias", "model.encoder._blocks.6._bn1.running_mean", "model.encoder._blocks.6._bn1.running_var", "model.encoder._blocks.6._se_reduce.weight", "model.encoder._blocks.6._se_reduce.bias", "model.encoder._blocks.6._se_expand.weight", "model.encoder._blocks.6._se_expand.bias", "model.encoder._blocks.6._project_conv.weight", "model.encoder._blocks.6._bn2.weight", "model.encoder._blocks.6._bn2.bias", "model.encoder._blocks.6._bn2.running_mean", "model.encoder._blocks.6._bn2.running_var", "model.encoder._blocks.7._expand_conv.weight", "model.encoder._blocks.7._bn0.weight", "model.encoder._blocks.7._bn0.bias", "model.encoder._blocks.7._bn0.running_mean", "model.encoder._blocks.7._bn0.running_var", "model.encoder._blocks.7._depthwise_conv.weight", "model.encoder._blocks.7._bn1.weight", "model.encoder._blocks.7._bn1.bias", "model.encoder._blocks.7._bn1.running_mean", "model.encoder._blocks.7._bn1.running_var", "model.encoder._blocks.7._se_reduce.weight", "model.encoder._blocks.7._se_reduce.bias", "model.encoder._blocks.7._se_expand.weight", "model.encoder._blocks.7._se_expand.bias", "model.encoder._blocks.7._project_conv.weight", "model.encoder._blocks.7._bn2.weight", "model.encoder._blocks.7._bn2.bias", "model.encoder._blocks.7._bn2.running_mean", "model.encoder._blocks.7._bn2.running_var", "model.encoder._blocks.8._expand_conv.weight", "model.encoder._blocks.8._bn0.weight", "model.encoder._blocks.8._bn0.bias", 
"model.encoder._blocks.8._bn0.running_mean", "model.encoder._blocks.8._bn0.running_var", "model.encoder._blocks.8._depthwise_conv.weight", "model.encoder._blocks.8._bn1.weight", "model.encoder._blocks.8._bn1.bias", "model.encoder._blocks.8._bn1.running_mean", "model.encoder._blocks.8._bn1.running_var", "model.encoder._blocks.8._se_reduce.weight", "model.encoder._blocks.8._se_reduce.bias", "model.encoder._blocks.8._se_expand.weight", "model.encoder._blocks.8._se_expand.bias", "model.encoder._blocks.8._project_conv.weight", "model.encoder._blocks.8._bn2.weight", "model.encoder._blocks.8._bn2.bias", "model.encoder._blocks.8._bn2.running_mean", "model.encoder._blocks.8._bn2.running_var", "model.encoder._blocks.9._expand_conv.weight", "model.encoder._blocks.9._bn0.weight", "model.encoder._blocks.9._bn0.bias", "model.encoder._blocks.9._bn0.running_mean", "model.encoder._blocks.9._bn0.running_var", "model.encoder._blocks.9._depthwise_conv.weight", "model.encoder._blocks.9._bn1.weight", "model.encoder._blocks.9._bn1.bias", "model.encoder._blocks.9._bn1.running_mean", "model.encoder._blocks.9._bn1.running_var", "model.encoder._blocks.9._se_reduce.weight", "model.encoder._blocks.9._se_reduce.bias", "model.encoder._blocks.9._se_expand.weight", "model.encoder._blocks.9._se_expand.bias", "model.encoder._blocks.9._project_conv.weight", "model.encoder._blocks.9._bn2.weight", "model.encoder._blocks.9._bn2.bias", "model.encoder._blocks.9._bn2.running_mean", "model.encoder._blocks.9._bn2.running_var", "model.encoder._blocks.10._expand_conv.weight", "model.encoder._blocks.10._bn0.weight", "model.encoder._blocks.10._bn0.bias", "model.encoder._blocks.10._bn0.running_mean", "model.encoder._blocks.10._bn0.running_var", "model.encoder._blocks.10._depthwise_conv.weight", "model.encoder._blocks.10._bn1.weight", "model.encoder._blocks.10._bn1.bias", "model.encoder._blocks.10._bn1.running_mean", "model.encoder._blocks.10._bn1.running_var", "model.encoder._blocks.10._se_reduce.weight", "model.encoder._blocks.10._se_reduce.bias", "model.encoder._blocks.10._se_expand.weight", "model.encoder._blocks.10._se_expand.bias", "model.encoder._blocks.10._project_conv.weight", "model.encoder._blocks.10._bn2.weight", "model.encoder._blocks.10._bn2.bias", "model.encoder._blocks.10._bn2.running_mean", "model.encoder._blocks.10._bn2.running_var", "model.encoder._blocks.11._expand_conv.weight", "model.encoder._blocks.11._bn0.weight", "model.encoder._blocks.11._bn0.bias", "model.encoder._blocks.11._bn0.running_mean", "model.encoder._blocks.11._bn0.running_var", "model.encoder._blocks.11._depthwise_conv.weight", "model.encoder._blocks.11._bn1.weight", "model.encoder._blocks.11._bn1.bias", "model.encoder._blocks.11._bn1.running_mean", "model.encoder._blocks.11._bn1.running_var", "model.encoder._blocks.11._se_reduce.weight", "model.encoder._blocks.11._se_reduce.bias", "model.encoder._blocks.11._se_expand.weight", "model.encoder._blocks.11._se_expand.bias", "model.encoder._blocks.11._project_conv.weight", "model.encoder._blocks.11._bn2.weight", "model.encoder._blocks.11._bn2.bias", "model.encoder._blocks.11._bn2.running_mean", "model.encoder._blocks.11._bn2.running_var", "model.encoder._blocks.12._expand_conv.weight", "model.encoder._blocks.12._bn0.weight", "model.encoder._blocks.12._bn0.bias", "model.encoder._blocks.12._bn0.running_mean", "model.encoder._blocks.12._bn0.running_var", "model.encoder._blocks.12._depthwise_conv.weight", "model.encoder._blocks.12._bn1.weight", "model.encoder._blocks.12._bn1.bias", 
"model.encoder._blocks.12._bn1.running_mean", "model.encoder._blocks.12._bn1.running_var", "model.encoder._blocks.12._se_reduce.weight", "model.encoder._blocks.12._se_reduce.bias", "model.encoder._blocks.12._se_expand.weight", "model.encoder._blocks.12._se_expand.bias", "model.encoder._blocks.12._project_conv.weight", "model.encoder._blocks.12._bn2.weight", "model.encoder._blocks.12._bn2.bias", "model.encoder._blocks.12._bn2.running_mean", "model.encoder._blocks.12._bn2.running_var", "model.encoder._blocks.13._expand_conv.weight", "model.encoder._blocks.13._bn0.weight", "model.encoder._blocks.13._bn0.bias", "model.encoder._blocks.13._bn0.running_mean", "model.encoder._blocks.13._bn0.running_var", "model.encoder._blocks.13._depthwise_conv.weight", "model.encoder._blocks.13._bn1.weight", "model.encoder._blocks.13._bn1.bias", "model.encoder._blocks.13._bn1.running_mean", "model.encoder._blocks.13._bn1.running_var", "model.encoder._blocks.13._se_reduce.weight", "model.encoder._blocks.13._se_reduce.bias", "model.encoder._blocks.13._se_expand.weight", "model.encoder._blocks.13._se_expand.bias", "model.encoder._blocks.13._project_conv.weight", "model.encoder._blocks.13._bn2.weight", "model.encoder._blocks.13._bn2.bias", "model.encoder._blocks.13._bn2.running_mean", "model.encoder._blocks.13._bn2.running_var", "model.encoder._blocks.14._expand_conv.weight", "model.encoder._blocks.14._bn0.weight", "model.encoder._blocks.14._bn0.bias", "model.encoder._blocks.14._bn0.running_mean", "model.encoder._blocks.14._bn0.running_var", "model.encoder._blocks.14._depthwise_conv.weight", "model.encoder._blocks.14._bn1.weight", "model.encoder._blocks.14._bn1.bias", "model.encoder._blocks.14._bn1.running_mean", "model.encoder._blocks.14._bn1.running_var", "model.encoder._blocks.14._se_reduce.weight", "model.encoder._blocks.14._se_reduce.bias", "model.encoder._blocks.14._se_expand.weight", "model.encoder._blocks.14._se_expand.bias", "model.encoder._blocks.14._project_conv.weight", "model.encoder._blocks.14._bn2.weight", "model.encoder._blocks.14._bn2.bias", "model.encoder._blocks.14._bn2.running_mean", "model.encoder._blocks.14._bn2.running_var", "model.encoder._blocks.15._expand_conv.weight", "model.encoder._blocks.15._bn0.weight", "model.encoder._blocks.15._bn0.bias", "model.encoder._blocks.15._bn0.running_mean", "model.encoder._blocks.15._bn0.running_var", "model.encoder._blocks.15._depthwise_conv.weight", "model.encoder._blocks.15._bn1.weight", "model.encoder._blocks.15._bn1.bias", "model.encoder._blocks.15._bn1.running_mean", "model.encoder._blocks.15._bn1.running_var", "model.encoder._blocks.15._se_reduce.weight", "model.encoder._blocks.15._se_reduce.bias", "model.encoder._blocks.15._se_expand.weight", "model.encoder._blocks.15._se_expand.bias", "model.encoder._blocks.15._project_conv.weight", "model.encoder._blocks.15._bn2.weight", "model.encoder._blocks.15._bn2.bias", "model.encoder._blocks.15._bn2.running_mean", "model.encoder._blocks.15._bn2.running_var", "model.encoder._conv_head.weight", "model.encoder._bn1.weight", "model.encoder._bn1.bias", "model.encoder._bn1.running_mean", "model.encoder._bn1.running_var", "model.decoder.aspp.0.convs.0.0.weight", "model.decoder.aspp.0.convs.0.1.weight", "model.decoder.aspp.0.convs.0.1.bias", "model.decoder.aspp.0.convs.0.1.running_mean", "model.decoder.aspp.0.convs.0.1.running_var", "model.decoder.aspp.0.convs.1.0.0.weight", "model.decoder.aspp.0.convs.1.0.1.weight", "model.decoder.aspp.0.convs.1.1.weight", "model.decoder.aspp.0.convs.1.1.bias", 
"model.decoder.aspp.0.convs.1.1.running_mean", "model.decoder.aspp.0.convs.1.1.running_var", "model.decoder.aspp.0.convs.2.0.0.weight", "model.decoder.aspp.0.convs.2.0.1.weight", "model.decoder.aspp.0.convs.2.1.weight", "model.decoder.aspp.0.convs.2.1.bias", "model.decoder.aspp.0.convs.2.1.running_mean", "model.decoder.aspp.0.convs.2.1.running_var", "model.decoder.aspp.0.convs.3.0.0.weight", "model.decoder.aspp.0.convs.3.0.1.weight", "model.decoder.aspp.0.convs.3.1.weight", "model.decoder.aspp.0.convs.3.1.bias", "model.decoder.aspp.0.convs.3.1.running_mean", "model.decoder.aspp.0.convs.3.1.running_var", "model.decoder.aspp.0.convs.4.1.weight", "model.decoder.aspp.0.convs.4.2.weight", "model.decoder.aspp.0.convs.4.2.bias", "model.decoder.aspp.0.convs.4.2.running_mean", "model.decoder.aspp.0.convs.4.2.running_var", "model.decoder.aspp.0.project.0.weight", "model.decoder.aspp.0.project.1.weight", "model.decoder.aspp.0.project.1.bias", "model.decoder.aspp.0.project.1.running_mean", "model.decoder.aspp.0.project.1.running_var", "model.decoder.aspp.1.0.weight", "model.decoder.aspp.1.1.weight", "model.decoder.aspp.2.weight", "model.decoder.aspp.2.bias", "model.decoder.aspp.2.running_mean", "model.decoder.aspp.2.running_var", "model.decoder.block1.0.weight", "model.decoder.block1.1.weight", "model.decoder.block1.1.bias", "model.decoder.block1.1.running_mean", "model.decoder.block1.1.running_var", "model.decoder.block2.0.0.weight", "model.decoder.block2.0.1.weight", "model.decoder.block2.1.weight", "model.decoder.block2.1.bias", "model.decoder.block2.1.running_mean", "model.decoder.block2.1.running_var", "model.segmentation_head.0.weight", "model.segmentation_head.0.bias". Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "callbacks", "optimizer_states", "lr_schedulers".

qubvel commented 1 month ago

Hi, try selecting the state dict from inside the checkpoint when loading: model.load_state_dict(checkpoint["state_dict"]). A Lightning .ckpt file is a dictionary that also holds the epoch, optimizer states, callbacks, etc. (the "Unexpected key(s)" in your error), which is why loading the whole checkpoint directly fails.
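For reference, a minimal loading sketch, assuming the PolypModel definition and constructor arguments from the post above (the checkpoint path is a placeholder):

import torch

# a Lightning .ckpt file is a dict wrapping the weights together with training metadata
checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")  # placeholder path

# recreate the module exactly as it was trained, then load only the "state_dict" entry
model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

Alternatively, Lightning's own loader should also work, as long as you pass the constructor arguments again (they are not stored in the checkpoint because save_hyperparameters() is not called here): model = PolypModel.load_from_checkpoint("path/to/checkpoint.ckpt", arch="DeepLabV3Plus", encoder_name="efficientnet-b0", in_channels=3, out_classes=1).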

YUHSINCHENG1230 commented 1 week ago

I also tried model.load_state_dict(checkpoint["state_dict"]), but I still get the same problem. What should I do next?

qubvel commented 1 week ago

It would be easier to debug if you compare the keys of the loaded checkpoint's state dict with the keys of the model's state dict; that way you can figure out what is going wrong during loading.
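For example, something like this (a minimal sketch; replace the checkpoint path with yours):

import torch

ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")  # placeholder path
ckpt_keys = set(ckpt["state_dict"].keys())
model_keys = set(model.state_dict().keys())

# keys the model expects but the checkpoint does not provide
print("missing:", sorted(model_keys - ckpt_keys)[:20])
# keys the checkpoint provides but the model does not expect
print("unexpected:", sorted(ckpt_keys - model_keys)[:20])

If the two sets differ only by a prefix (e.g. "model." present on one side but not the other), the checkpoint was likely saved from a different wrapper class, and you can rename the keys before loading.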