albumentations-team / autoalbument

AutoML for image augmentation. AutoAlbument uses the Faster AutoAugment algorithm to find optimal augmentation policies. Documentation - https://albumentations.ai/docs/autoalbument/
MIT License

How to load a trained model from a checkpoint #43

Open Violonur-PavelBI opened 1 year ago

Violonur-PavelBI commented 1 year ago

How do I load a trained model from a checkpoint when the model is configured through these settings?

classification_model:
  _target_: autoalbument.faster_autoaugment.models.ClassificationModel
  num_classes: 10
  architecture: resnet50 # mobilenetv2_100
  pretrained: False
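For context, AutoAlbument reads its settings with Hydra, and a Hydra-style `_target_` config such as this one is turned into a constructor call, so it amounts to `ClassificationModel(num_classes=10, architecture="resnet50", pretrained=False)`. Hydra's real entry point is `hydra.utils.instantiate`; `instantiate_sketch` below is a hypothetical, stripped-down illustration of the idea (no recursion, no interpolation), demonstrated on a standard-library target so it runs without AutoAlbument installed.

```python
import importlib
from typing import Any, Dict


def instantiate_sketch(config: Dict[str, Any]) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate.

    Resolves the dotted path in ``_target_`` and calls it with the
    remaining keys as keyword arguments. No recursion or interpolation.
    """
    params = dict(config)  # copy so the caller's config is not mutated
    module_path, _, attr_name = params.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**params)


# With AutoAlbument installed, this config would resolve to
# ClassificationModel(num_classes=10, architecture="resnet50", pretrained=False).
classification_model_cfg = {
    "_target_": "autoalbument.faster_autoaugment.models.ClassificationModel",
    "num_classes": 10,
    "architecture": "resnet50",
    "pretrained": False,
}

# Runnable demo with a standard-library target instead:
frac = instantiate_sketch({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})
print(frac)  # 3/4
```

The point for this issue: the config only supplies constructor arguments, so a hand-written copy of the class has to accept (or hardcode) the same values to rebuild an equivalent model.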

I copied the model class from autoalbument/autoalbument/faster_autoaugment/models/main_model.py:

from typing import Tuple

import pytorch_lightning as pl
import timm
from torch import Tensor, nn

class BaseDiscriminator(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError

class ClassificationModel(BaseDiscriminator):
    def __init__(self):
        super().__init__()
        # Hardcoded here; the YAML config passes architecture, num_classes and
        # pretrained instead (note pretrained=True differs from the config's False).
        self.base_model = timm.create_model("resnet50", pretrained=True)
        self.base_model.reset_classifier(10)  # replace the head with a 10-class one
        self.classifier = self.base_model.get_classifier()
        num_features = self.classifier.in_features
        # Discriminator head: one score per image.
        self.discriminator = nn.Sequential(
            nn.Linear(num_features, num_features), nn.ReLU(), nn.Linear(num_features, 1)
        )

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Shared pooled features feed both the classifier and the discriminator.
        x = self.base_model.forward_features(input)
        x = self.base_model.global_pool(x).flatten(1)
        return self.classifier(x), self.discriminator(x).view(-1)

I created the model:

model_ft=ClassificationModel()

When I load the weights from the checkpoint into the model, an error is raised:

model_ft_afteFsAA = model_ft.load_from_checkpoint(
    "/workspace/proj/autoalbument/examples/imageWoof/resnet50/outputs/2022-07-18/10-12-02/checkpoints/last.ckpt",
    batch_size=32,
    num_workers=10,
)
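`load_from_checkpoint` builds a fresh instance of the class and then, by default, loads the checkpoint's `state_dict` strictly. PyTorch's strict `load_state_dict` raises a `RuntimeError` listing "missing" keys (expected by the model, absent from the file) and "unexpected" keys (present in the file, unknown to the model). The hypothetical helper `compare_keys` below reproduces that bookkeeping on plain key lists, which can help when inspecting what a checkpoint actually holds. The sample keys are illustrative only: `wrapper.` is a made-up prefix, not a real AutoAlbument attribute name.

```python
from typing import Iterable, List, Tuple


def compare_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected), the two lists a strict
    torch load_state_dict reports in its RuntimeError."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)     # expected by the model, absent from the file
    unexpected = sorted(ckpt_set - model_set)  # present in the file, unknown to the model
    return missing, unexpected


# Illustrative keys only; "wrapper." is a hypothetical prefix.
# With real objects you would pass model.state_dict().keys() and
# torch.load(ckpt_path, map_location="cpu")["state_dict"].keys().
model_keys = ["base_model.conv1.weight", "discriminator.0.weight"]
checkpoint_keys = [
    "wrapper.base_model.conv1.weight",
    "wrapper.discriminator.0.weight",
    "policy_model.sub_policies.0.stages.0.operations.0.saved_image_shape",
]
missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print(missing)  # ['base_model.conv1.weight', 'discriminator.0.weight']
print(unexpected)
```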

RuntimeError: Error(s) in loading state_dict for ClassificationModel: Missing key(s) in state_dict: "base_model.conv1.weight" ... "policy_model.sub_policies.99.stages.3.operations.13.saved_image_shape".
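The `policy_model...` keys in this message suggest the checkpoint stores the augmentation policy alongside the classifier, so its keys do not line up one-to-one with `ClassificationModel`. One possible workaround, not verified against AutoAlbument, is to load the raw checkpoint (Lightning keeps the weights under its `"state_dict"` key), keep only the classifier's entries with their prefix stripped, and pass the result to `load_state_dict`. `strip_prefix` and `PREFIX` are hypothetical names; the actual prefix has to be read off the printed checkpoint keys.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str) -> "OrderedDict[str, Any]":
    """Keep only entries whose key starts with `prefix`, with the prefix removed."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Usage with a real checkpoint (requires torch; PREFIX is an assumption,
# check the printed keys first; on recent PyTorch versions a trusted Lightning
# checkpoint may need torch.load(..., weights_only=False)):
#   import torch
#   ckpt = torch.load(ckpt_path, map_location="cpu")
#   print(list(ckpt["state_dict"])[:10])
#   model_ft = ClassificationModel()
#   model_ft.load_state_dict(strip_prefix(ckpt["state_dict"], PREFIX))

# Self-contained demo on placeholder values ("wrapper." is hypothetical):
demo = OrderedDict([
    ("wrapper.base_model.conv1.weight", 1),
    ("wrapper.discriminator.0.weight", 2),
    ("policy_model.sub_policies.0.stages.0.operations.0.saved_image_shape", 3),
])
print(dict(strip_prefix(demo, "wrapper.")))
# {'base_model.conv1.weight': 1, 'discriminator.0.weight': 2}
```

Filtering by prefix drops the policy entries entirely, so a strict load only sees the classifier and discriminator weights.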