openvinotoolkit / anomalib

An anomaly detection library comprising state-of-the-art algorithms and features such as experiment management, hyper-parameter optimization, and edge inference.
https://anomalib.readthedocs.io/en/latest/
Apache License 2.0

📋 [TASK] How can I use a local backbone? #2439

Open davy-blavette opened 5 days ago

davy-blavette commented 5 days ago

Describe the task

How can I load the backbone model (pretrained weights) from a local file? Each time, it searches the Hugging Face hub, which is rather painful and not really possible in my case, as I don't have an internet connection:

Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)

I downloaded model.safetensors (renamed to resnet18.safetensors) from

https://huggingface.co/timm/resnet18.a1_in1k/tree/main

and tried to adapt the code:

        # Requires: import os, logging; from safetensors import safe_open;
        # from anomalib.models import Padim
        local_model_path = os.path.join(self.file_handler.pretrained, "resnet18.safetensors")

        if os.path.exists(local_model_path):
            self.model = Padim(
                backbone=self.file_handler.args.backbone,
                layers=["layer1", "layer2", "layer3"],
            )

            # Read the raw safetensors checkpoint into a state dict.
            with safe_open(local_model_path, framework="pt", device="cpu") as f:
                state_dict = {k: f.get_tensor(k) for k in f.keys()}

            if hasattr(self.model, "feature_extractor"):
                self.model.feature_extractor.load_state_dict(state_dict)
            else:
                logging.warning(
                    "No attribute 'feature_extractor'; local weights were not loaded.")
        else:
            self.model = Padim(
                backbone=self.file_handler.args.backbone,
                layers=["layer1", "layer2", "layer3"],
                pre_trained=True,
            )

but without success...
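One way to see why the load_state_dict call fails is to compare key names. This is only a diagnostic sketch (it assumes the state_dict and self.model from the snippet above, and the Padim -> .model -> .feature_extractor nesting that comes up later in this thread):

        # Diagnostic sketch: compare the checkpoint keys with the keys that
        # Padim's wrapped feature extractor actually expects.
        ckpt_keys = set(state_dict)
        extractor_keys = set(self.model.model.feature_extractor.state_dict())
        print(sorted(ckpt_keys - extractor_keys)[:5])   # only in the checkpoint
        print(sorted(extractor_keys - ckpt_keys)[:5])   # only in the model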

Acceptance Criteria

Code that works with a local pretrained backbone file; I don't want any downloads.

Priority

High

Related Epic

No response

Estimated Time

No response

Current Status

Not Started

Additional Information

No response

davy-blavette commented 4 days ago

I found this in the source code:

        # Extract backbone-name and weight-URI from the backbone string.
        if "__AT__" in backbone:
            backbone, uri = backbone.split("__AT__")
            pretrained_cfg = timm.models.registry.get_pretrained_cfg(backbone)
            # Override pretrained_cfg["url"] to use different pretrained weights.
            pretrained_cfg["url"] = uri

So I adapted my code:

        local_weights_path = f"file://{self.file_handler.pretrained}/model.safetensors"
        backbone = f"{self.file_handler.args.backbone}__AT__{local_weights_path}"

        self.model = Padim(
            backbone=backbone,
            layers=["layer1", "layer2", "layer3"],
        )

but I get this error:

    AttributeError: module 'timm.models' has no attribute 'registry'

so I downgraded timm:

pip install 'timm==0.6.13'

but this causes another problem:

    magic_number = pickle_module.load(f, **pickle_load_args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    UnpicklingError: invalid load key, '*'.

alexriedel1 commented 2 days ago

If you want to use timm model weights, you can download the weights and save them under ~/.cache/torch/hub/checkpoints (or wherever your directory for cached timm models is).
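For example, to find that directory programmatically (a minimal sketch; whether timm actually picks the file up from there depends on the timm version and on whether the pretrained config points to a plain URL or to the Hugging Face hub):

    import os
    import torch

    # Directory where torch/timm look for already-downloaded checkpoints.
    checkpoints_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
    print(checkpoints_dir)  # typically ~/.cache/torch/hub/checkpoints
    # Copy the downloaded weights file (e.g. resnet18-f37072fd.pth) into this
    # directory so that pre_trained=True can resolve it without a download.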

davy-blavette commented 2 days ago

Or change the OS environment: os.environ['HF_HOME'] = PRETRAINED ... which I did, but it's a lousy solution...
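For reference, that workaround looks roughly like this (a sketch; PRETRAINED is my local cache directory, and the variables must be set before timm / huggingface_hub are imported):

    import os

    PRETRAINED = "/path/to/pretrained_cache"  # hypothetical local cache directory

    # Point the Hugging Face cache at the local directory and (optionally)
    # force offline mode so nothing is ever fetched from the hub.
    os.environ["HF_HOME"] = PRETRAINED
    os.environ["HF_HUB_OFFLINE"] = "1"

    from anomalib.models import Padim  # imported only after the env vars are set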

alexriedel1 commented 2 days ago

> Or change the OS environment: os.environ['HF_HOME'] = PRETRAINED ... which I did, but it's a lousy solution...

Hm, I see. One thing you could try:

    # Requires: import timm; from anomalib.models import Padim
    model = Padim(
        backbone="resnet18",
        layers=["layer1", "layer2", "layer3"],
        pre_trained=False,
    )

    # Replace the bundled feature extractor with one built from a local weights file.
    model.model.feature_extractor.feature_extractor = timm.create_model(
        "resnet18",
        pretrained=True,
        pretrained_cfg_overlay=dict(file="C:/Users/Alex/Desktop/resnet18-f37072fd.pth"),
        features_only=True,
        exportable=True,
        out_indices=model.model.feature_extractor.idx,
    )

This doesn't need an internet connection. You simply patch the feature extractor of Padim with a new one that is initialized from a local weights file.
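A quick sanity check of the patched model (a sketch; it assumes anomalib's feature extractor returns a dict mapping layer names to feature maps, and uses an arbitrary input size):

    import torch

    # Run a dummy batch through the patched extractor, entirely offline.
    extractor = model.model.feature_extractor.eval()
    with torch.no_grad():
        features = extractor(torch.randn(1, 3, 256, 256))
    print({name: tuple(f.shape) for name, f in features.items()})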