Li-yachuan / NBED

Code for the paper "A new baseline for edge detection: Make Encoder-Decoder great again"

Inference script #3

nikolamilovic-ft opened this issue 1 month ago

nikolamilovic-ft commented 1 month ago

Could you add an easy-to-run inference script to the repository? I'm having trouble running edge extraction on an arbitrary input image using the checkpoint model you provided in the README.

Li-yachuan commented 1 month ago

What is the error?


nikolamilovic-ft commented 1 month ago

Here is my my_inference.py script:

import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import argparse
import os

# Import the model definition
from model.basemodel import Basemodel

def load_model(args, device):
    # Initialize the model
    model = Basemodel(encoder_name=args.encoder,
                      decoder_name=args.decoder,
                      head_name=args.head).to(device)

    # Load the pretrained weights
    if args.resume is not None:
        ckpt = torch.load(args.resume, weights_only=True, map_location=device)
        if 'state_dict' in ckpt:
            ckpt = ckpt['state_dict']

        model.load_state_dict(ckpt)
    else:
        print("No pretrained weights provided. Using untrained model.")

    model.eval()
    return model

def preprocess_image(image_path, device):
    # Define the image transformations
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # Add batch dimension
    image = image.to(device)
    return image

def run_inference(model, input_image, device):
    with torch.no_grad():
        output = model(input_image)

    if isinstance(output, (tuple, list)):
        output = output[0]
    # Process the output tensor to create an image
    output = output.squeeze().cpu().numpy()
    # Normalize the output to [0, 1]
    output = (output - output.min()) / (output.max() - output.min())
    output_image = Image.fromarray((output * 255).astype(np.uint8))
    return output_image

def main():
    parser = argparse.ArgumentParser(description='NBED Edge Detection Inference')
    parser.add_argument('--input', type=str, required=True, help='Path to the input image file')
    parser.add_argument('--output', type=str, default='output.png', help='Path to save the output image')
    parser.add_argument('--resume', type=str, default=None, help='Path to the pretrained model weights (.pth file)')
    parser.add_argument("--encoder", default="Dul-M36",
                        help="Options: caformer-m36, Dul-M36")
    parser.add_argument("--decoder", default="unetp",
                        help="Options: unet, unetp, default")
    parser.add_argument("--head", default="default",
                        help="Options: default, aspp, atten, cofusion")
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model
    model = load_model(args, device)

    # Preprocess the input image
    input_image = preprocess_image(args.input, device)

    # Run inference
    output_image = run_inference(model, input_image, device)

    # Save the output image
    output_image.save(args.output)
    print(f'Edge-detected image saved to {args.output}')

if __name__ == '__main__':
    main()
# Run command: python3 my_inference.py --input Imgs/test_assembly.png --output outputs/edge_output.png --resume model/caformer_m36_384_in21ft1k.pth

And this is the error message I get when loading the model:

RuntimeError: Error(s) in loading state_dict for Basemodel:
        Missing key(s) in state_dict: "encoder.conv1.0.weight", "encoder.conv1.0.bias", "encoder.conv2.0.weight", "encoder.conv2.0.bias", "encoder.downsample_layers.0.conv.weight", "encoder.downsample_layers.0.conv.bias", "encoder.downsample_layers.0.post_norm.weight", "encoder.downsample_layers.1.pre_norm.weight", "encoder.downsample_layers.1.conv.weight", "encoder.downsample_layers.1.conv.bias", "encoder.downsample_layers.2.pre_norm.weight", "encoder.downsample_layers.2.conv.weight", "encoder.downsample_layers.2.conv.bias", "encoder.stages.0.0.norm1.weight", "encoder.stages.0.0.token_mixer.pwconv1.weight", "encoder.stages.0.0.token_mixer.act1.scale", "encoder.stages.0.0.token_mixer.act1.bias", "encoder.stages.0.0.token_mixer.dwconv.weight", "encoder.stages.0.0.token_mixer.pwconv2.weight", "encoder.stages.0.0.norm2.weight", "encoder.stages.0.0.mlp.fc1.weight", "encoder.stages.0.0.mlp.act.scale", "encoder.stages.0.0.mlp.act.bias", "encoder.stages.0.0.mlp.fc2.weight", "encoder.stages.0.1.norm1.weight", "encoder.stages.0.1.token_mixer.pwconv1.weight", "encoder.stages.0.1.token_mixer.act1.scale", "encoder.stages.0.1.token_mixer.act1.bias", "encoder.stages.0.1.token_mixer.dwconv.weight", "encoder.stages.0.1.token_mixer.pwconv2.weight", "encoder.stages.0.1.norm2.weight", "encoder.stages.0.1.mlp.fc1.weight", "encoder.stages.0.1.mlp.act.scale", "encoder.stages.0.1.mlp.act.bias", "encoder.stages.0.1.mlp.fc2.weight", "encoder.stages.0.2.norm1.weight", "encoder.stages.0.2.token_mixer.pwconv1.weight", "encoder.stages.0.2.token_mixer.act1.scale", "encoder.stages.0.2.token_mixer.act1.bias", "encoder.stages.0.2.token_mixer.dwconv.weight", "encoder.stages.0.2.token_mixer.pwconv2.weight", "encoder.stages.0.2.norm2.weight", "encoder.stages.0.2.mlp.fc1.weight", "encoder.stages.0.2.mlp.act.scale", "encoder.stages.0.2.mlp.act.bias", "encoder.stages.0.2.mlp.fc2.weight", "encoder.stages.1.0.norm1.weight", "encoder.stages.1.0.token_mixer.pwconv1.weight", "encoder.stages.1.0.token_mixer.act1.scale", "encoder.stages.1.0.token_mixer.act1.bias", "encoder.stages.1.0.token_mixer.dwconv.weight", "encoder.stages.1.0.token_mixer.pwconv2.weight", "encoder.stages.1.0.norm2.weight", "encoder.stages.1.0.mlp.fc1.weight", "encoder.stages.1.0.mlp.act.scale", "encoder.stages.1.0.mlp.act.bias", "encoder.stages.1.0.mlp.fc2.weight", "encoder.stages.1.1.norm1.weight", "encoder.stages.1.1.token_mixer.pwconv1.weight", "encoder.stages.1.1.token_mixer.act1.scale", "encoder.stages.1.1.token_mixer.act1.bias", "encoder.stages.1.1.token_mixer.dwconv.weight", "encoder.stages.1.1.token_mixer.pwconv2.weight", "encoder.stages.1.1.norm2.weight", "encoder.stages.1.1.mlp.fc1.weight", "encoder.stages.1.1.mlp.act.scale", "encoder.stages.1.1.mlp.act.bias", "encoder.stages.1.1.mlp.fc2.weight", "encoder.stages.1.2.norm1.weight", "encoder.stages.1.2.token_mixer.pwconv1.weight", "encoder.stages.1.2.token_mixer.act1.scale", "encoder.stages.1.2.token_mixer.act1.bias", "encoder.stages.1.2.token_mixer.dwconv.weight", "encoder.stages.1.2.token_mixer.pwconv2.weight", "encoder.stages.1.2.norm2.weight", "encoder.stages.1.2.mlp.fc1.weight", "encoder.stages.1.2.mlp.act.scale", "encoder.stages.1.2.mlp.act.bias", "encoder.stages.1.2.mlp.fc2.weight", "encoder.stages.1.3.norm1.weight", "encoder.stages.1.3.token_mixer.pwconv1.weight", "encoder.stages.1.3.token_mixer.act1.scale", "encoder.stages.1.3.token_mixer.act1.bias", "encoder.stages.1.3.token_mixer.dwconv.weight", "encoder.stages.1.3.token_mixer.pwconv2.weight", "encoder.stages.1.3.norm2.weight", 
"encoder.stages.1.3.mlp.fc1.weight", "encoder.stages.1.3.mlp.act.scale", "encoder.stages.1.3.mlp.act.bias", "encoder.stages.1.3.mlp.fc2.weight", "encoder.stages.1.4.norm1.weight", "encoder.stages.1.4.token_mixer.pwconv1.weight", "encoder.stages.1.4.token_mixer.act1.scale", "encoder.stages.1.4.token_mixer.act1.bias", "encoder.stages.1.4.token_mixer.dwconv.weight", "encoder.stages.1.4.token_mixer.pwconv2.weight", "encoder.stages.1.4.norm2.weight", "encoder.stages.1.4.mlp.fc1.weight", "encoder.stages.1.4.mlp.act.scale", "encoder.stages.1.4.mlp.act.bias", "encoder.stages.1.4.mlp.fc2.weight", "encoder.stages.1.5.norm1.weight", "encoder.stages.1.5.token_mixer.pwconv1.weight", "encoder.stages.1.5.token_mixer.act1.scale", "encoder.stages.1.5.token_mixer.act1.bias", "encoder.stages.1.5.token_mixer.dwconv.weight", "encoder.stages.1.5.token_mixer.pwconv2.weight", "encoder.stages.1.5.norm2.weight", "encoder.stages.1.5.mlp.fc1.weight", "encoder.stages.1.5.mlp.act.scale", "encoder.stages.1.5.mlp.act.bias", "encoder.stages.1.5.mlp.fc2.weight", "encoder.stages.1.6.norm1.weight", "encoder.stages.1.6.token_mixer.pwconv1.weight", "encoder.stages.1.6.token_mixer.act1.scale", "encoder.stages.1.6.token_mixer.act1.bias", "encoder.stages.1.6.token_mixer.dwconv.weight", "encoder.stages.1.6.token_mixer.pwconv2.weight", "encoder.stages.1.6.norm2.weight", "encoder.stages.1.6.mlp.fc1.weight", "encoder.stages.1.6.mlp.act.scale", "encoder.stages.1.6.mlp.act.bias", "encoder.stages.1.6.mlp.fc2.weight", "encoder.stages.1.7.norm1.weight", "encoder.stages.1.7.token_mixer.pwconv1.weight", "encoder.stages.1.7.token_mixer.act1.scale", "encoder.stages.1.7.token_mixer.act1.bias", "encoder.stages.1.7.token_mixer.dwconv.weight", "encoder.stages.1.7.token_mixer.pwconv2.weight", "encoder.stages.1.7.norm2.weight", "encoder.stages.1.7.mlp.fc1.weight", "encoder.stages.1.7.mlp.act.scale", "encoder.stages.1.7.mlp.act.bias", "encoder.stages.1.7.mlp.fc2.weight", "encoder.stages.1.8.norm1.weight", "encoder.stages.1.8.token_mixer.pwconv1.weight", "encoder.stages.1.8.token_mixer.act1.scale", "encoder.stages.1.8.token_mixer.act1.bias", "encoder.stages.1.8.token_mixer.dwconv.weight", "encoder.stages.1.8.token_mixer.pwconv2.weight", "encoder.stages.1.8.norm2.weight", "encoder.stages.1.8.mlp.fc1.weight", "encoder.stages.1.8.mlp.act.scale", "encoder.stages.1.8.mlp.act.bias", "encoder.stages.1.8.mlp.fc2.weight", "encoder.stages.1.9.norm1.weight", "encoder.stages.1.9.token_mixer.pwconv1.weight", "encoder.stages.1.9.token_mixer.act1.scale", "encoder.stages.1.9.token_mixer.act1.bias", "encoder.stages.1.9.token_mixer.dwconv.weight", "encoder.stages.1.9.token_mixer.pwconv2.weight", "encoder.stages.1.9.norm2.weight", "encoder.stages.1.9.mlp.fc1.weight", "encoder.stages.1.9.mlp.act.scale", "encoder.stages.1.9.mlp.act.bias", "encoder.stages.1.9.mlp.fc2.weight", "encoder.stages.1.10.norm1.weight", "encoder.stages.1.10.token_mixer.pwconv1.weight", "encoder.stages.1.10.token_mixer.act1.scale", "encoder.stages.1.10.token_mixer.act1.bias", "encoder.stages.1.10.token_mixer.dwconv.weight", "encoder.stages.1.10.token_mixer.pwconv2.weight", "encoder.stages.1.10.norm2.weight", "encoder.stages.1.10.mlp.fc1.weight", "encoder.stages.1.10.mlp.act.scale", "encoder.stages.1.10.mlp.act.bias", "encoder.stages.1.10.mlp.fc2.weight", "encoder.stages.1.11.norm1.weight", "encoder.stages.1.11.token_mixer.pwconv1.weight", "encoder.stages.1.11.token_mixer.act1.scale", "encoder.stages.1.11.token_mixer.act1.bias", "encoder.stages.1.11.token_mixer.dwconv.weight", 
"encoder.stages.1.11.token_mixer.pwconv2.weight", "encoder.stages.1.11.norm2.weight", "encoder.stages.1.11.mlp.fc1.weight", "encoder.stages.1.11.mlp.act.scale", "encoder.stages.1.11.mlp.act.bias", "encoder.stages.1.11.mlp.fc2.weight", "encoder.stages.2.0.norm1.weight", "encoder.stages.2.0.token_mixer.qkv.weight", "encoder.stages.2.0.token_mixer.proj.weight", "encoder.stages.2.0.res_scale1.scale", "encoder.stages.2.0.norm2.weight", "encoder.stages.2.0.mlp.fc1.weight", "encoder.stages.2.0.mlp.act.scale", "encoder.stages.2.0.mlp.act.bias", "encoder.stages.2.0.mlp.fc2.weight", "encoder.stages.2.0.res_scale2.scale", "encoder.stages.2.1.norm1.weight", "encoder.stages.2.1.token_mixer.qkv.weight", "encoder.stages.2.1.token_mixer.proj.weight", "encoder.stages.2.1.res_scale1.scale", "encoder.stages.2.1.norm2.weight", "encoder.stages.2.1.mlp.fc1.weight", "encoder.stages.2.1.mlp.act.scale", "encoder.stages.2.1.mlp.act.bias", "encoder.stages.2.1.mlp.fc2.weight", "encoder.stages.2.1.res_scale2.scale", "encoder.stages.2.2.norm1.weight", "encoder.stages.2.2.token_mixer.qkv.weight", "encoder.stages.2.2.token_mixer.proj.weight", "encoder.stages.2.2.res_scale1.scale", "encoder.stages.2.2.norm2.weight", "encoder.stages.2.2.mlp.fc1.weight", "encoder.stages.2.2.mlp.act.scale", "encoder.stages.2.2.mlp.act.bias", "encoder.stages.2.2.mlp.fc2.weight", "encoder.stages.2.2.res_scale2.scale", "encoder.stages.2.3.norm1.weight", "encoder.stages.2.3.token_mixer.qkv.weight", "encoder.stages.2.3.token_mixer.proj.weight", "encoder.stages.2.3.res_scale1.scale", "encoder.stages.2.3.norm2.weight", "encoder.stages.2.3.mlp.fc1.weight", "encoder.stages.2.3.mlp.act.scale", "encoder.stages.2.3.mlp.act.bias", "encoder.stages.2.3.mlp.fc2.weight", "encoder.stages.2.3.res_scale2.scale", "encoder.stages.2.4.norm1.weight", "encoder.stages.2.4.token_mixer.qkv.weight", "encoder.stages.2.4.token_mixer.proj.weight", "encoder.stages.2.4.res_scale1.scale", "encoder.stages.2.4.norm2.weight", "encoder.stages.2.4.mlp.fc1.weight", "encoder.stages.2.4.mlp.act.scale", "encoder.stages.2.4.mlp.act.bias", "encoder.stages.2.4.mlp.fc2.weight", "encoder.stages.2.4.res_scale2.scale", "encoder.stages.2.5.norm1.weight", "encoder.stages.2.5.token_mixer.qkv.weight", "encoder.stages.2.5.token_mixer.proj.weight", "encoder.stages.2.5.res_scale1.scale", "encoder.stages.2.5.norm2.weight", "encoder.stages.2.5.mlp.fc1.weight", "encoder.stages.2.5.mlp.act.scale", "encoder.stages.2.5.mlp.act.bias", "encoder.stages.2.5.mlp.fc2.weight", "encoder.stages.2.5.res_scale2.scale", "encoder.stages.2.6.norm1.weight", "encoder.stages.2.6.token_mixer.qkv.weight", "encoder.stages.2.6.token_mixer.proj.weight", "encoder.stages.2.6.res_scale1.scale", "encoder.stages.2.6.norm2.weight", "encoder.stages.2.6.mlp.fc1.weight", "encoder.stages.2.6.mlp.act.scale", "encoder.stages.2.6.mlp.act.bias", "encoder.stages.2.6.mlp.fc2.weight", "encoder.stages.2.6.res_scale2.scale", "encoder.stages.2.7.norm1.weight", "encoder.stages.2.7.token_mixer.qkv.weight", "encoder.stages.2.7.token_mixer.proj.weight", "encoder.stages.2.7.res_scale1.scale", "encoder.stages.2.7.norm2.weight", "encoder.stages.2.7.mlp.fc1.weight", "encoder.stages.2.7.mlp.act.scale", "encoder.stages.2.7.mlp.act.bias", "encoder.stages.2.7.mlp.fc2.weight", "encoder.stages.2.7.res_scale2.scale", "encoder.stages.2.8.norm1.weight", "encoder.stages.2.8.token_mixer.qkv.weight", "encoder.stages.2.8.token_mixer.proj.weight", "encoder.stages.2.8.res_scale1.scale", "encoder.stages.2.8.norm2.weight", "encoder.stages.2.8.mlp.fc1.weight", 
"encoder.stages.2.8.mlp.act.scale", "encoder.stages.2.8.mlp.act.bias", "encoder.stages.2.8.mlp.fc2.weight", "encoder.stages.2.8.res_scale2.scale", "encoder.stages.2.9.norm1.weight", "encoder.stages.2.9.token_mixer.qkv.weight", "encoder.stages.2.9.token_mixer.proj.weight", "encoder.stages.2.9.res_scale1.scale", "encoder.stages.2.9.norm2.weight", "encoder.stages.2.9.mlp.fc1.weight", "encoder.stages.2.9.mlp.act.scale", "encoder.stages.2.9.mlp.act.bias", "encoder.stages.2.9.mlp.fc2.weight", "encoder.stages.2.9.res_scale2.scale", "encoder.stages.2.10.norm1.weight", "encoder.stages.2.10.token_mixer.qkv.weight", "encoder.stages.2.10.token_mixer.proj.weight", "encoder.stages.2.10.res_scale1.scale", "encoder.stages.2.10.norm2.weight", "encoder.stages.2.10.mlp.fc1.weight", "encoder.stages.2.10.mlp.act.scale", "encoder.stages.2.10.mlp.act.bias", "encoder.stages.2.10.mlp.fc2.weight", "encoder.stages.2.10.res_scale2.scale", "encoder.stages.2.11.norm1.weight", "encoder.stages.2.11.token_mixer.qkv.weight", "encoder.stages.2.11.token_mixer.proj.weight", "encoder.stages.2.11.res_scale1.scale", "encoder.stages.2.11.norm2.weight", "encoder.stages.2.11.mlp.fc1.weight", "encoder.stages.2.11.mlp.act.scale", "encoder.stages.2.11.mlp.act.bias", "encoder.stages.2.11.mlp.fc2.weight", "encoder.stages.2.11.res_scale2.scale", "encoder.stages.2.12.norm1.weight", "encoder.stages.2.12.token_mixer.qkv.weight", "encoder.stages.2.12.token_mixer.proj.weight", "encoder.stages.2.12.res_scale1.scale", "encoder.stages.2.12.norm2.weight", "encoder.stages.2.12.mlp.fc1.weight", "encoder.stages.2.12.mlp.act.scale", "encoder.stages.2.12.mlp.act.bias", "encoder.stages.2.12.mlp.fc2.weight", "encoder.stages.2.12.res_scale2.scale", "encoder.stages.2.13.norm1.weight", "encoder.stages.2.13.token_mixer.qkv.weight", "encoder.stages.2.13.token_mixer.proj.weight", "encoder.stages.2.13.res_scale1.scale", "encoder.stages.2.13.norm2.weight", "encoder.stages.2.13.mlp.fc1.weight", "encoder.stages.2.13.mlp.act.scale", "encoder.stages.2.13.mlp.act.bias", "encoder.stages.2.13.mlp.fc2.weight", "encoder.stages.2.13.res_scale2.scale", "encoder.stages.2.14.norm1.weight", "encoder.stages.2.14.token_mixer.qkv.weight", "encoder.stages.2.14.token_mixer.proj.weight", "encoder.stages.2.14.res_scale1.scale", "encoder.stages.2.14.norm2.weight", "encoder.stages.2.14.mlp.fc1.weight", "encoder.stages.2.14.mlp.act.scale", "encoder.stages.2.14.mlp.act.bias", "encoder.stages.2.14.mlp.fc2.weight", "encoder.stages.2.14.res_scale2.scale", "encoder.stages.2.15.norm1.weight", "encoder.stages.2.15.token_mixer.qkv.weight", "encoder.stages.2.15.token_mixer.proj.weight", "encoder.stages.2.15.res_scale1.scale", "encoder.stages.2.15.norm2.weight", "encoder.stages.2.15.mlp.fc1.weight", "encoder.stages.2.15.mlp.act.scale", "encoder.stages.2.15.mlp.act.bias", "encoder.stages.2.15.mlp.fc2.weight", "encoder.stages.2.15.res_scale2.scale", "encoder.stages.2.16.norm1.weight", "encoder.stages.2.16.token_mixer.qkv.weight", "encoder.stages.2.16.token_mixer.proj.weight", "encoder.stages.2.16.res_scale1.scale", "encoder.stages.2.16.norm2.weight", "encoder.stages.2.16.mlp.fc1.weight", "encoder.stages.2.16.mlp.act.scale", "encoder.stages.2.16.mlp.act.bias", "encoder.stages.2.16.mlp.fc2.weight", "encoder.stages.2.16.res_scale2.scale", "encoder.stages.2.17.norm1.weight", "encoder.stages.2.17.token_mixer.qkv.weight", "encoder.stages.2.17.token_mixer.proj.weight", "encoder.stages.2.17.res_scale1.scale", "encoder.stages.2.17.norm2.weight", "encoder.stages.2.17.mlp.fc1.weight", 
"encoder.stages.2.17.mlp.act.scale", "encoder.stages.2.17.mlp.act.bias", "encoder.stages.2.17.mlp.fc2.weight", "encoder.stages.2.17.res_scale2.scale", "decoder.convs.conv0_0.conv1.0.weight", "decoder.convs.conv0_0.conv2.0.weight", "decoder.convs.conv1_0.conv1.0.weight", "decoder.convs.conv1_0.conv2.0.weight", "decoder.convs.conv2_0.conv1.0.weight", "decoder.convs.conv2_0.conv2.0.weight", "decoder.convs.conv3_0.conv1.0.weight", "decoder.convs.conv3_0.conv2.0.weight", "decoder.convs.conv0_1.conv1.0.weight", "decoder.convs.conv0_1.conv2.0.weight", "decoder.convs.conv1_1.conv1.0.weight", "decoder.convs.conv1_1.conv2.0.weight", "decoder.convs.conv2_1.conv1.0.weight", "decoder.convs.conv2_1.conv2.0.weight", "decoder.convs.conv0_2.conv1.0.weight", "decoder.convs.conv0_2.conv2.0.weight", "decoder.convs.conv1_2.conv1.0.weight", "decoder.convs.conv1_2.conv2.0.weight", "decoder.convs.conv0_3.conv1.0.weight", "decoder.convs.conv0_3.conv2.0.weight", "head.final.0.weight", "head.final.0.bias". 
        Unexpected key(s) in state_dict: "downsample_layers.0.conv.weight", "downsample_layers.0.conv.bias", "downsample_layers.0.post_norm.weight", "downsample_layers.1.pre_norm.weight", "downsample_layers.1.conv.weight", "downsample_layers.1.conv.bias", "downsample_layers.2.pre_norm.weight", "downsample_layers.2.conv.weight", "downsample_layers.2.conv.bias", "downsample_layers.3.pre_norm.weight", "downsample_layers.3.conv.weight", "downsample_layers.3.conv.bias", "stages.0.0.norm1.weight", "stages.0.0.token_mixer.pwconv1.weight", "stages.0.0.token_mixer.act1.scale", "stages.0.0.token_mixer.act1.bias", "stages.0.0.token_mixer.dwconv.weight", "stages.0.0.token_mixer.pwconv2.weight", "stages.0.0.norm2.weight", "stages.0.0.mlp.fc1.weight", "stages.0.0.mlp.act.scale", "stages.0.0.mlp.act.bias", "stages.0.0.mlp.fc2.weight", "stages.0.1.norm1.weight", "stages.0.1.token_mixer.pwconv1.weight", "stages.0.1.token_mixer.act1.scale", "stages.0.1.token_mixer.act1.bias", "stages.0.1.token_mixer.dwconv.weight", "stages.0.1.token_mixer.pwconv2.weight", "stages.0.1.norm2.weight", "stages.0.1.mlp.fc1.weight", "stages.0.1.mlp.act.scale", "stages.0.1.mlp.act.bias", "stages.0.1.mlp.fc2.weight", "stages.0.2.norm1.weight", "stages.0.2.token_mixer.pwconv1.weight", "stages.0.2.token_mixer.act1.scale", "stages.0.2.token_mixer.act1.bias", "stages.0.2.token_mixer.dwconv.weight", "stages.0.2.token_mixer.pwconv2.weight", "stages.0.2.norm2.weight", "stages.0.2.mlp.fc1.weight", "stages.0.2.mlp.act.scale", "stages.0.2.mlp.act.bias", "stages.0.2.mlp.fc2.weight", "stages.1.0.norm1.weight", "stages.1.0.token_mixer.pwconv1.weight", "stages.1.0.token_mixer.act1.scale", "stages.1.0.token_mixer.act1.bias", "stages.1.0.token_mixer.dwconv.weight", "stages.1.0.token_mixer.pwconv2.weight", "stages.1.0.norm2.weight", "stages.1.0.mlp.fc1.weight", "stages.1.0.mlp.act.scale", "stages.1.0.mlp.act.bias", "stages.1.0.mlp.fc2.weight", "stages.1.1.norm1.weight", "stages.1.1.token_mixer.pwconv1.weight", "stages.1.1.token_mixer.act1.scale", "stages.1.1.token_mixer.act1.bias", "stages.1.1.token_mixer.dwconv.weight", "stages.1.1.token_mixer.pwconv2.weight", "stages.1.1.norm2.weight", "stages.1.1.mlp.fc1.weight", "stages.1.1.mlp.act.scale", "stages.1.1.mlp.act.bias", "stages.1.1.mlp.fc2.weight", "stages.1.2.norm1.weight", "stages.1.2.token_mixer.pwconv1.weight", "stages.1.2.token_mixer.act1.scale", "stages.1.2.token_mixer.act1.bias", "stages.1.2.token_mixer.dwconv.weight", "stages.1.2.token_mixer.pwconv2.weight", "stages.1.2.norm2.weight", "stages.1.2.mlp.fc1.weight", "stages.1.2.mlp.act.scale", "stages.1.2.mlp.act.bias", "stages.1.2.mlp.fc2.weight", "stages.1.3.norm1.weight", "stages.1.3.token_mixer.pwconv1.weight", "stages.1.3.token_mixer.act1.scale", "stages.1.3.token_mixer.act1.bias", "stages.1.3.token_mixer.dwconv.weight", "stages.1.3.token_mixer.pwconv2.weight", "stages.1.3.norm2.weight", "stages.1.3.mlp.fc1.weight", "stages.1.3.mlp.act.scale", "stages.1.3.mlp.act.bias", "stages.1.3.mlp.fc2.weight", "stages.1.4.norm1.weight", "stages.1.4.token_mixer.pwconv1.weight", "stages.1.4.token_mixer.act1.scale", "stages.1.4.token_mixer.act1.bias", "stages.1.4.token_mixer.dwconv.weight", "stages.1.4.token_mixer.pwconv2.weight", "stages.1.4.norm2.weight", "stages.1.4.mlp.fc1.weight", "stages.1.4.mlp.act.scale", "stages.1.4.mlp.act.bias", "stages.1.4.mlp.fc2.weight", "stages.1.5.norm1.weight", "stages.1.5.token_mixer.pwconv1.weight", "stages.1.5.token_mixer.act1.scale", "stages.1.5.token_mixer.act1.bias", "stages.1.5.token_mixer.dwconv.weight", 
"stages.1.5.token_mixer.pwconv2.weight", "stages.1.5.norm2.weight", "stages.1.5.mlp.fc1.weight", "stages.1.5.mlp.act.scale", "stages.1.5.mlp.act.bias", "stages.1.5.mlp.fc2.weight", "stages.1.6.norm1.weight", "stages.1.6.token_mixer.pwconv1.weight", "stages.1.6.token_mixer.act1.scale", "stages.1.6.token_mixer.act1.bias", "stages.1.6.token_mixer.dwconv.weight", "stages.1.6.token_mixer.pwconv2.weight", "stages.1.6.norm2.weight", "stages.1.6.mlp.fc1.weight", "stages.1.6.mlp.act.scale", "stages.1.6.mlp.act.bias", "stages.1.6.mlp.fc2.weight", "stages.1.7.norm1.weight", "stages.1.7.token_mixer.pwconv1.weight", "stages.1.7.token_mixer.act1.scale", "stages.1.7.token_mixer.act1.bias", "stages.1.7.token_mixer.dwconv.weight", "stages.1.7.token_mixer.pwconv2.weight", "stages.1.7.norm2.weight", "stages.1.7.mlp.fc1.weight", "stages.1.7.mlp.act.scale", "stages.1.7.mlp.act.bias", "stages.1.7.mlp.fc2.weight", "stages.1.8.norm1.weight", "stages.1.8.token_mixer.pwconv1.weight", "stages.1.8.token_mixer.act1.scale", "stages.1.8.token_mixer.act1.bias", "stages.1.8.token_mixer.dwconv.weight", "stages.1.8.token_mixer.pwconv2.weight", "stages.1.8.norm2.weight", "stages.1.8.mlp.fc1.weight", "stages.1.8.mlp.act.scale", "stages.1.8.mlp.act.bias", "stages.1.8.mlp.fc2.weight", "stages.1.9.norm1.weight", "stages.1.9.token_mixer.pwconv1.weight", "stages.1.9.token_mixer.act1.scale", "stages.1.9.token_mixer.act1.bias", "stages.1.9.token_mixer.dwconv.weight", "stages.1.9.token_mixer.pwconv2.weight", "stages.1.9.norm2.weight", "stages.1.9.mlp.fc1.weight", "stages.1.9.mlp.act.scale", "stages.1.9.mlp.act.bias", "stages.1.9.mlp.fc2.weight", "stages.1.10.norm1.weight", "stages.1.10.token_mixer.pwconv1.weight", "stages.1.10.token_mixer.act1.scale", "stages.1.10.token_mixer.act1.bias", "stages.1.10.token_mixer.dwconv.weight", "stages.1.10.token_mixer.pwconv2.weight", "stages.1.10.norm2.weight", "stages.1.10.mlp.fc1.weight", "stages.1.10.mlp.act.scale", "stages.1.10.mlp.act.bias", "stages.1.10.mlp.fc2.weight", "stages.1.11.norm1.weight", "stages.1.11.token_mixer.pwconv1.weight", "stages.1.11.token_mixer.act1.scale", "stages.1.11.token_mixer.act1.bias", "stages.1.11.token_mixer.dwconv.weight", "stages.1.11.token_mixer.pwconv2.weight", "stages.1.11.norm2.weight", "stages.1.11.mlp.fc1.weight", "stages.1.11.mlp.act.scale", "stages.1.11.mlp.act.bias", "stages.1.11.mlp.fc2.weight", "stages.2.0.norm1.weight", "stages.2.0.token_mixer.qkv.weight", "stages.2.0.token_mixer.proj.weight", "stages.2.0.res_scale1.scale", "stages.2.0.norm2.weight", "stages.2.0.mlp.fc1.weight", "stages.2.0.mlp.act.scale", "stages.2.0.mlp.act.bias", "stages.2.0.mlp.fc2.weight", "stages.2.0.res_scale2.scale", "stages.2.1.norm1.weight", "stages.2.1.token_mixer.qkv.weight", "stages.2.1.token_mixer.proj.weight", "stages.2.1.res_scale1.scale", "stages.2.1.norm2.weight", "stages.2.1.mlp.fc1.weight", "stages.2.1.mlp.act.scale", "stages.2.1.mlp.act.bias", "stages.2.1.mlp.fc2.weight", "stages.2.1.res_scale2.scale", "stages.2.2.norm1.weight", "stages.2.2.token_mixer.qkv.weight", "stages.2.2.token_mixer.proj.weight", "stages.2.2.res_scale1.scale", "stages.2.2.norm2.weight", "stages.2.2.mlp.fc1.weight", "stages.2.2.mlp.act.scale", "stages.2.2.mlp.act.bias", "stages.2.2.mlp.fc2.weight", "stages.2.2.res_scale2.scale", "stages.2.3.norm1.weight", "stages.2.3.token_mixer.qkv.weight", "stages.2.3.token_mixer.proj.weight", "stages.2.3.res_scale1.scale", "stages.2.3.norm2.weight", "stages.2.3.mlp.fc1.weight", "stages.2.3.mlp.act.scale", "stages.2.3.mlp.act.bias", 
"stages.2.3.mlp.fc2.weight", "stages.2.3.res_scale2.scale", "stages.2.4.norm1.weight", "stages.2.4.token_mixer.qkv.weight", "stages.2.4.token_mixer.proj.weight", "stages.2.4.res_scale1.scale", "stages.2.4.norm2.weight", "stages.2.4.mlp.fc1.weight", "stages.2.4.mlp.act.scale", "stages.2.4.mlp.act.bias", "stages.2.4.mlp.fc2.weight", "stages.2.4.res_scale2.scale", "stages.2.5.norm1.weight", "stages.2.5.token_mixer.qkv.weight", "stages.2.5.token_mixer.proj.weight", "stages.2.5.res_scale1.scale", "stages.2.5.norm2.weight", "stages.2.5.mlp.fc1.weight", "stages.2.5.mlp.act.scale", "stages.2.5.mlp.act.bias", "stages.2.5.mlp.fc2.weight", "stages.2.5.res_scale2.scale", "stages.2.6.norm1.weight", "stages.2.6.token_mixer.qkv.weight", "stages.2.6.token_mixer.proj.weight", "stages.2.6.res_scale1.scale", "stages.2.6.norm2.weight", "stages.2.6.mlp.fc1.weight", "stages.2.6.mlp.act.scale", "stages.2.6.mlp.act.bias", "stages.2.6.mlp.fc2.weight", "stages.2.6.res_scale2.scale", "stages.2.7.norm1.weight", "stages.2.7.token_mixer.qkv.weight", "stages.2.7.token_mixer.proj.weight", "stages.2.7.res_scale1.scale", "stages.2.7.norm2.weight", "stages.2.7.mlp.fc1.weight", "stages.2.7.mlp.act.scale", "stages.2.7.mlp.act.bias", "stages.2.7.mlp.fc2.weight", "stages.2.7.res_scale2.scale", "stages.2.8.norm1.weight", "stages.2.8.token_mixer.qkv.weight", "stages.2.8.token_mixer.proj.weight", "stages.2.8.res_scale1.scale", "stages.2.8.norm2.weight", "stages.2.8.mlp.fc1.weight", "stages.2.8.mlp.act.scale", "stages.2.8.mlp.act.bias", "stages.2.8.mlp.fc2.weight", "stages.2.8.res_scale2.scale", "stages.2.9.norm1.weight", "stages.2.9.token_mixer.qkv.weight", "stages.2.9.token_mixer.proj.weight", "stages.2.9.res_scale1.scale", "stages.2.9.norm2.weight", "stages.2.9.mlp.fc1.weight", "stages.2.9.mlp.act.scale", "stages.2.9.mlp.act.bias", "stages.2.9.mlp.fc2.weight", "stages.2.9.res_scale2.scale", "stages.2.10.norm1.weight", "stages.2.10.token_mixer.qkv.weight", "stages.2.10.token_mixer.proj.weight", "stages.2.10.res_scale1.scale", "stages.2.10.norm2.weight", "stages.2.10.mlp.fc1.weight", "stages.2.10.mlp.act.scale", "stages.2.10.mlp.act.bias", "stages.2.10.mlp.fc2.weight", "stages.2.10.res_scale2.scale", "stages.2.11.norm1.weight", "stages.2.11.token_mixer.qkv.weight", "stages.2.11.token_mixer.proj.weight", "stages.2.11.res_scale1.scale", "stages.2.11.norm2.weight", "stages.2.11.mlp.fc1.weight", "stages.2.11.mlp.act.scale", "stages.2.11.mlp.act.bias", "stages.2.11.mlp.fc2.weight", "stages.2.11.res_scale2.scale", "stages.2.12.norm1.weight", "stages.2.12.token_mixer.qkv.weight", "stages.2.12.token_mixer.proj.weight", "stages.2.12.res_scale1.scale", "stages.2.12.norm2.weight", "stages.2.12.mlp.fc1.weight", "stages.2.12.mlp.act.scale", "stages.2.12.mlp.act.bias", "stages.2.12.mlp.fc2.weight", "stages.2.12.res_scale2.scale", "stages.2.13.norm1.weight", "stages.2.13.token_mixer.qkv.weight", "stages.2.13.token_mixer.proj.weight", "stages.2.13.res_scale1.scale", "stages.2.13.norm2.weight", "stages.2.13.mlp.fc1.weight", "stages.2.13.mlp.act.scale", "stages.2.13.mlp.act.bias", "stages.2.13.mlp.fc2.weight", "stages.2.13.res_scale2.scale", "stages.2.14.norm1.weight", "stages.2.14.token_mixer.qkv.weight", "stages.2.14.token_mixer.proj.weight", "stages.2.14.res_scale1.scale", "stages.2.14.norm2.weight", "stages.2.14.mlp.fc1.weight", "stages.2.14.mlp.act.scale", "stages.2.14.mlp.act.bias", "stages.2.14.mlp.fc2.weight", "stages.2.14.res_scale2.scale", "stages.2.15.norm1.weight", "stages.2.15.token_mixer.qkv.weight", 
"stages.2.15.token_mixer.proj.weight", "stages.2.15.res_scale1.scale", "stages.2.15.norm2.weight", "stages.2.15.mlp.fc1.weight", "stages.2.15.mlp.act.scale", "stages.2.15.mlp.act.bias", "stages.2.15.mlp.fc2.weight", "stages.2.15.res_scale2.scale", "stages.2.16.norm1.weight", "stages.2.16.token_mixer.qkv.weight", "stages.2.16.token_mixer.proj.weight", "stages.2.16.res_scale1.scale", "stages.2.16.norm2.weight", "stages.2.16.mlp.fc1.weight", "stages.2.16.mlp.act.scale", "stages.2.16.mlp.act.bias", "stages.2.16.mlp.fc2.weight", "stages.2.16.res_scale2.scale", "stages.2.17.norm1.weight", "stages.2.17.token_mixer.qkv.weight", "stages.2.17.token_mixer.proj.weight", "stages.2.17.res_scale1.scale", "stages.2.17.norm2.weight", "stages.2.17.mlp.fc1.weight", "stages.2.17.mlp.act.scale", "stages.2.17.mlp.act.bias", "stages.2.17.mlp.fc2.weight", "stages.2.17.res_scale2.scale", "stages.3.0.norm1.weight", "stages.3.0.token_mixer.qkv.weight", "stages.3.0.token_mixer.proj.weight", "stages.3.0.res_scale1.scale", "stages.3.0.norm2.weight", "stages.3.0.mlp.fc1.weight", "stages.3.0.mlp.act.scale", "stages.3.0.mlp.act.bias", "stages.3.0.mlp.fc2.weight", "stages.3.0.res_scale2.scale", "stages.3.1.norm1.weight", "stages.3.1.token_mixer.qkv.weight", "stages.3.1.token_mixer.proj.weight", "stages.3.1.res_scale1.scale", "stages.3.1.norm2.weight", "stages.3.1.mlp.fc1.weight", "stages.3.1.mlp.act.scale", "stages.3.1.mlp.act.bias", "stages.3.1.mlp.fc2.weight", "stages.3.1.res_scale2.scale", "stages.3.2.norm1.weight", "stages.3.2.token_mixer.qkv.weight", "stages.3.2.token_mixer.proj.weight", "stages.3.2.res_scale1.scale", "stages.3.2.norm2.weight", "stages.3.2.mlp.fc1.weight", "stages.3.2.mlp.act.scale", "stages.3.2.mlp.act.bias", "stages.3.2.mlp.fc2.weight", "stages.3.2.res_scale2.scale", "norm.weight", "norm.bias", "head.fc1.weight", "head.fc1.bias", "head.norm.weight", "head.norm.bias", "head.fc2.weight", "head.fc2.bias". 

This error is raised by the call model.load_state_dict(ckpt).

Before running this script I downloaded the checkpoint you mentioned in README: https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth
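
A quick way to inspect what that file actually contains (a diagnostic snippet of my own, not something from the repo):

import torch

# Load the downloaded checkpoint on CPU and look at its top-level structure.
ckpt = torch.load("model/caformer_m36_384_in21ft1k.pth", map_location="cpu")
print(type(ckpt))                # a plain dict, i.e. a flat state_dict
print("state_dict" in ckpt)      # False -- there is no 'state_dict' wrapper
print(list(ckpt.keys())[:5])     # keys like "downsample_layers.0.conv.weight"

The key names match the "Unexpected key(s)" in the error above, so the file appears to be the raw ImageNet-pretrained CaFormer backbone rather than a full NBED checkpoint.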

I have seen that in the main.py script you load the model from a checkpoint like this:

if args.resume is not None:
    ckpt = torch.load(args.resume, map_location='cpu')['state_dict']
    ckpt["encoder.conv2.0.weight"] = ckpt.pop("encoder.conv2.1.weight")
    ckpt["encoder.conv2.0.bias"] = ckpt.pop("encoder.conv2.1.bias")
    model.load_state_dict(ckpt)
    print("load pretrained model, successfully!")

However, I cannot load the model this way, as I get an error saying there is no 'state_dict' key in the ckpt object. I also cannot pop "encoder.conv2.1.weight" and "encoder.conv2.1.bias", since there are no such keys in the ckpt object.
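
Judging by the missing keys, the encoder submodule inside Basemodel is named encoder, so one workaround would be to load the backbone weights into the encoder alone and leave the decoder and head randomly initialized. A minimal sketch under that assumption (load_backbone_only is my own hypothetical helper, not part of the repo):

import torch

def load_backbone_only(model, path):
    # The downloaded file is a flat CaFormer state_dict, so prefix every key
    # with "encoder." to match the Basemodel naming seen in the error above.
    ckpt = torch.load(path, map_location="cpu")
    ckpt = {"encoder." + k: v for k, v in ckpt.items()}
    # strict=False skips parameters the backbone file lacks (decoder, head,
    # conv1/conv2 stems) and ignores prefixed backbone keys the model does
    # not have (e.g. the ImageNet classifier, now "encoder.head.fc1.*").
    missing, unexpected = model.load_state_dict(ckpt, strict=False)
    print(f"{len(missing)} missing / {len(unexpected)} unexpected keys")
    return model

Even if this loads without crashing, it only reproduces the initialization used for training; getting meaningful edge maps presumably still requires a checkpoint trained for NBED.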

Here are my Python library versions inside the virtual environment:

Ubuntu 22.04

Python 3.10.12

Package                  Version
------------------------ ----------
certifi                  2024.8.30
charset-normalizer       3.4.0
filelock                 3.16.1
fsspec                   2024.10.0
huggingface-hub          0.26.0
idna                     3.10
Jinja2                   3.1.4
MarkupSafe               3.0.2
mpmath                   1.3.0
networkx                 3.4.1
numpy                    2.1.2
nvidia-cublas-cu12       12.4.5.8
nvidia-cuda-cupti-cu12   12.4.127
nvidia-cuda-nvrtc-cu12   12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.2.1.3
nvidia-curand-cu12       10.3.5.147
nvidia-cusolver-cu12     11.6.1.9
nvidia-cusparse-cu12     12.3.1.170
nvidia-nccl-cu12         2.21.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.4.127
packaging                24.1
pillow                   11.0.0
pip                      22.0.2
PyYAML                   6.0.2
requests                 2.32.3
safetensors              0.4.5
setuptools               59.6.0
sympy                    1.13.1
timm                     1.0.11
torch                    2.5.0
torchvision              0.20.0
tqdm                     4.66.5
triton                   3.1.0
typing_extensions        4.12.2
urllib3                  2.2.3

Could you please provide us with a separate inference script that we can easily run on an arbitrary input image using the pretrained weights you mentioned?

Li-yachuan commented 1 month ago

There is no problem with the code; other people have successfully reproduced the results using the code I provided. I am currently on a week-long trip and my computer is not with me. You can try to solve the problem first; if it is not solved, I can get back to you next week.

This error is raised by the model.load_state_dict(ckpt) call in load_model.
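Judging from the key names, the checkpoint appears to contain only the raw CAFormer backbone (keys like stages.* with no encoder. prefix, plus an ImageNet classifier head.fc1/head.fc2), while load_state_dict expects weights for the full Basemodel (encoder.*, decoder.*, head.*). Below is a minimal sketch of the kind of remapping I would expect to be needed — untested, and assuming model.encoder really is the CAFormer backbone:

# Hypothetical workaround (untested): load the raw backbone weights into
# the encoder submodule only, dropping the ImageNet classifier head and
# final norm that the edge-detection model does not have.
ckpt = torch.load(args.resume, map_location='cpu')
backbone_state = {k: v for k, v in ckpt.items()
                  if not k.startswith(('head.', 'norm.'))}
missing, unexpected = model.encoder.load_state_dict(backbone_state, strict=False)
print("still missing:", missing)
print("still unexpected:", unexpected)

Even if this loads the encoder, the decoder and head would still be randomly initialized, so I suspect the checkpoint linked in the README (caformer_m36_384_in21ft1k) is just the ImageNet-pretrained backbone, not a trained edge detector.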

Before running this script, I downloaded the checkpoint you mention in the README: https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth
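In case it helps reproduce, the same file can also be fetched programmatically; mapping the URL to repo_id/filename here is my own inference, not something documented in the repo:

# Equivalent programmatic download via the Hugging Face Hub; repo_id and
# filename are read off the README URL above.
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="sail/dl",
                            filename="caformer/caformer_m36_384_in21ft1k.pth")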

I have seen that in the main.py script you load the model from a checkpoint like this:

if args.resume is not None:
    ckpt = torch.load(args.resume, map_location='cpu')['state_dict']
    ckpt["encoder.conv2.0.weight"] = ckpt.pop("encoder.conv2.1.weight")
    ckpt["encoder.conv2.0.bias"] = ckpt.pop("encoder.conv2.1.bias")
    model.load_state_dict(ckpt)
    print("load pretrained model, successfully!")
However, I cannot load the model this way either, as I get an error saying there is no 'state_dict' key in the ckpt object. I also cannot pop "encoder.conv2.1.weight" and "encoder.conv2.1.bias", since no such keys exist in the ckpt object.
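For what it's worth, a defensive variant of that loader — just a sketch based on the two failures above, assuming nothing about the checkpoint beyond what the errors show — would be:

# Hypothetical defensive loader (untested): tolerate checkpoints that are
# not wrapped in a 'state_dict' entry, and only remap the conv2 keys when
# they are actually present.
ckpt = torch.load(args.resume, map_location='cpu')
ckpt = ckpt.get('state_dict', ckpt)
if "encoder.conv2.1.weight" in ckpt:
    ckpt["encoder.conv2.0.weight"] = ckpt.pop("encoder.conv2.1.weight")
    ckpt["encoder.conv2.0.bias"] = ckpt.pop("encoder.conv2.1.bias")
model.load_state_dict(ckpt)

Even with these guards, load_state_dict still fails with the key mismatch quoted above, which again points to the checkpoint itself containing only backbone weights.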

Here are the library versions inside my virtual environment (Ubuntu 22.04, Python 3.10.12):

Package                  Version
------------------------ ----------
certifi                  2024.8.30
charset-normalizer       3.4.0
filelock                 3.16.1
fsspec                   2024.10.0
huggingface-hub          0.26.0
idna                     3.10
Jinja2                   3.1.4
MarkupSafe               3.0.2
mpmath                   1.3.0
networkx                 3.4.1
numpy                    2.1.2
nvidia-cublas-cu12       12.4.5.8
nvidia-cuda-cupti-cu12   12.4.127
nvidia-cuda-nvrtc-cu12   12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.2.1.3
nvidia-curand-cu12       10.3.5.147
nvidia-cusolver-cu12     11.6.1.9
nvidia-cusparse-cu12     12.3.1.170
nvidia-nccl-cu12         2.21.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.4.127
packaging                24.1
pillow                   11.0.0
pip                      22.0.2
PyYAML                   6.0.2
requests                 2.32.3
safetensors              0.4.5
setuptools               59.6.0
sympy                    1.13.1
timm                     1.0.11
torch                    2.5.0
torchvision              0.20.0
tqdm                     4.66.5
triton                   3.1.0
typing_extensions        4.12.2
urllib3                  2.2.3
Could you please provide us with a separate inference script that we can easily run on an arbitrary input image using the pretrained weights you mentioned?
