Open nikolamilovic-ft opened 1 month ago
what is the error?
---Original--- From: "Nikola @.> Date: Mon, Oct 21, 2024 11:00 AM To: @.>; Cc: @.***>; Subject: [Li-yachuan/NBED] Inference script (Issue #3)
Could you add an easy to run inference script in the repository? I'm having trouble running the edge extraction on arbitrary input image using the checkpoint model you provided in the README.
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you are subscribed to this thread.Message ID: @.***>
Here is the my_inference.py script:
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import argparse
import os
# Import the model definition
from model.basemodel import Basemodel
def load_model(args, device):
    """Build the NBED ``Basemodel`` and optionally load checkpoint weights.

    Args:
        args: Namespace providing the ``encoder``, ``decoder``, ``head`` and
            ``resume`` attributes.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model, switched to evaluation mode.

    Raises:
        RuntimeError: If the checkpoint keys do not match the model. When the
            checkpoint contains no ``encoder.*`` keys at all, the message
            includes a hint that the file is probably not a full model
            checkpoint.
    """
    model = Basemodel(encoder_name=args.encoder,
                      decoder_name=args.decoder,
                      head_name=args.head).to(device)

    if args.resume is None:
        print("No pretrained weights provided. Using untrained model.")
    else:
        ckpt = torch.load(args.resume, weights_only=True, map_location=device)
        # Some checkpoints wrap the weights in a 'state_dict' entry.
        if 'state_dict' in ckpt:
            ckpt = ckpt['state_dict']

        # main.py renames these two keys before loading; mirror that here so
        # the same checkpoints load in this script. Checkpoints that already
        # use the new names (or lack the old ones) are left untouched.
        for suffix in ("weight", "bias"):
            old_key = f"encoder.conv2.1.{suffix}"
            new_key = f"encoder.conv2.0.{suffix}"
            if old_key in ckpt and new_key not in ckpt:
                ckpt[new_key] = ckpt.pop(old_key)

        try:
            model.load_state_dict(ckpt)
        except RuntimeError as err:
            # The model's own keys are prefixed with 'encoder.', 'decoder.'
            # and 'head.'; a file with no 'encoder.' keys is most likely a
            # bare backbone checkpoint, so say so instead of only dumping
            # the key lists.
            if not any(str(key).startswith("encoder.") for key in ckpt):
                raise RuntimeError(
                    f"Checkpoint '{args.resume}' contains no 'encoder.*' keys. "
                    "It looks like backbone-only pretrained weights rather "
                    "than a full Basemodel checkpoint; pass the trained "
                    "model checkpoint to --resume instead."
                ) from err
            raise

    model.eval()
    return model
def preprocess_image(image_path, device):
    """Read an image file and turn it into a normalized batch of one.

    Args:
        image_path: Path of the image to open.
        device: Device the resulting tensor is moved to.

    Returns:
        The transformed image tensor with a leading batch dimension, on
        ``device``.
    """
    # Fixed 512x512 resize, tensor conversion, then per-channel normalization
    # (the mean/std values look like the usual ImageNet statistics).
    resize = transforms.Resize((512, 512))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    pipeline = transforms.Compose([resize, to_tensor, normalize])

    # Force three channels so grayscale/RGBA inputs go through the same path.
    rgb_picture = Image.open(image_path).convert('RGB')

    # unsqueeze(0) adds the batch dimension in front.
    batch = pipeline(rgb_picture).unsqueeze(0)
    return batch.to(device)
def run_inference(model, input_image, device):
    """Run the model on one preprocessed image and return the result as an image.

    Args:
        model: Callable model; invoked as ``model(input_image)``.
        input_image: Preprocessed input tensor passed to the model.
        device: Not used by this function; kept so existing callers that pass
            it keep working.

    Returns:
        A PIL image built from the model output after min-max scaling to the
        0-255 range.
    """
    with torch.no_grad():
        output = model(input_image)

    # When the model returns several outputs, only the first one is used.
    if isinstance(output, (tuple, list)):
        output = output[0]

    # Drop singleton dimensions and move to a NumPy array on the host.
    output = output.squeeze().cpu().numpy()

    # Min-max normalize to [0, 1]. A constant map has zero range, which would
    # otherwise divide 0 by 0 and feed NaN into the uint8 cast; emit an
    # all-zero image in that case instead.
    lowest = output.min()
    value_range = output.max() - lowest
    if value_range > 0:
        output = (output - lowest) / value_range
    else:
        output = np.zeros_like(output)

    output_image = Image.fromarray((output * 255).astype(np.uint8))
    return output_image
def main():
    """Command-line entry point: load the model, process one image, save the result.

    Parses ``--input``, ``--output``, ``--resume``, ``--encoder``,
    ``--decoder`` and ``--head``, runs the model on the input image and
    writes the resulting image to the output path, creating the output
    directory first if it does not exist yet.
    """
    parser = argparse.ArgumentParser(description='NBED Edge Detection Inference')
    parser.add_argument('--input', type=str, required=True, help='Path to the input image file')
    parser.add_argument('--output', type=str, default='output.png', help='Path to save the output image')
    parser.add_argument('--resume', type=str, default=None, help='Path to the pretrained model weights (.pth file)')
    parser.add_argument("--encoder", default="Dul-M36",
                        help="Options: caformer-m36, Dul-M36")
    parser.add_argument("--decoder", default="unetp",
                        help="Options: unet, unetp, default")
    parser.add_argument("--head", default="default",
                        help="Options: default, aspp, atten, cofusion")
    args = parser.parse_args()

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model
    model = load_model(args, device)

    # Preprocess the input image
    input_image = preprocess_image(args.input, device)

    # Run inference
    output_image = run_inference(model, input_image, device)

    # Saving into a directory that does not exist yet (e.g.
    # outputs/edge_output.png) would fail, so create it first. A bare file
    # name has an empty dirname and needs no directory.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the output image
    output_image.save(args.output)
    print(f'Edge-detected image saved to {args.output}')


# Run the CLI only when this file is executed as a script.
if __name__ == '__main__':
    main()
# Run command: python3 my_inference.py --input Imgs/test_assembly.png --output outputs/edge_output.png --resume model/caformer_m36_384_in21ft1k.pth
And this is the error message I get on model loading:
RuntimeError: Error(s) in loading state_dict for Basemodel:
Missing key(s) in state_dict: "encoder.conv1.0.weight", "encoder.conv1.0.bias", "encoder.conv2.0.weight", "encoder.conv2.0.bias", "encoder.downsample_layers.0.conv.weight", "encoder.downsample_layers.0.conv.bias", "encoder.downsample_layers.0.post_norm.weight", "encoder.downsample_layers.1.pre_norm.weight", "encoder.downsample_layers.1.conv.weight", "encoder.downsample_layers.1.conv.bias", "encoder.downsample_layers.2.pre_norm.weight", "encoder.downsample_layers.2.conv.weight", "encoder.downsample_layers.2.conv.bias", "encoder.stages.0.0.norm1.weight", "encoder.stages.0.0.token_mixer.pwconv1.weight", "encoder.stages.0.0.token_mixer.act1.scale", "encoder.stages.0.0.token_mixer.act1.bias", "encoder.stages.0.0.token_mixer.dwconv.weight", "encoder.stages.0.0.token_mixer.pwconv2.weight", "encoder.stages.0.0.norm2.weight", "encoder.stages.0.0.mlp.fc1.weight", "encoder.stages.0.0.mlp.act.scale", "encoder.stages.0.0.mlp.act.bias", "encoder.stages.0.0.mlp.fc2.weight", "encoder.stages.0.1.norm1.weight", "encoder.stages.0.1.token_mixer.pwconv1.weight", "encoder.stages.0.1.token_mixer.act1.scale", "encoder.stages.0.1.token_mixer.act1.bias", "encoder.stages.0.1.token_mixer.dwconv.weight", "encoder.stages.0.1.token_mixer.pwconv2.weight", "encoder.stages.0.1.norm2.weight", "encoder.stages.0.1.mlp.fc1.weight", "encoder.stages.0.1.mlp.act.scale", "encoder.stages.0.1.mlp.act.bias", "encoder.stages.0.1.mlp.fc2.weight", "encoder.stages.0.2.norm1.weight", "encoder.stages.0.2.token_mixer.pwconv1.weight", "encoder.stages.0.2.token_mixer.act1.scale", "encoder.stages.0.2.token_mixer.act1.bias", "encoder.stages.0.2.token_mixer.dwconv.weight", "encoder.stages.0.2.token_mixer.pwconv2.weight", "encoder.stages.0.2.norm2.weight", "encoder.stages.0.2.mlp.fc1.weight", "encoder.stages.0.2.mlp.act.scale", "encoder.stages.0.2.mlp.act.bias", "encoder.stages.0.2.mlp.fc2.weight", "encoder.stages.1.0.norm1.weight", "encoder.stages.1.0.token_mixer.pwconv1.weight", 
"encoder.stages.1.0.token_mixer.act1.scale", "encoder.stages.1.0.token_mixer.act1.bias", "encoder.stages.1.0.token_mixer.dwconv.weight", "encoder.stages.1.0.token_mixer.pwconv2.weight", "encoder.stages.1.0.norm2.weight", "encoder.stages.1.0.mlp.fc1.weight", "encoder.stages.1.0.mlp.act.scale", "encoder.stages.1.0.mlp.act.bias", "encoder.stages.1.0.mlp.fc2.weight", "encoder.stages.1.1.norm1.weight", "encoder.stages.1.1.token_mixer.pwconv1.weight", "encoder.stages.1.1.token_mixer.act1.scale", "encoder.stages.1.1.token_mixer.act1.bias", "encoder.stages.1.1.token_mixer.dwconv.weight", "encoder.stages.1.1.token_mixer.pwconv2.weight", "encoder.stages.1.1.norm2.weight", "encoder.stages.1.1.mlp.fc1.weight", "encoder.stages.1.1.mlp.act.scale", "encoder.stages.1.1.mlp.act.bias", "encoder.stages.1.1.mlp.fc2.weight", "encoder.stages.1.2.norm1.weight", "encoder.stages.1.2.token_mixer.pwconv1.weight", "encoder.stages.1.2.token_mixer.act1.scale", "encoder.stages.1.2.token_mixer.act1.bias", "encoder.stages.1.2.token_mixer.dwconv.weight", "encoder.stages.1.2.token_mixer.pwconv2.weight", "encoder.stages.1.2.norm2.weight", "encoder.stages.1.2.mlp.fc1.weight", "encoder.stages.1.2.mlp.act.scale", "encoder.stages.1.2.mlp.act.bias", "encoder.stages.1.2.mlp.fc2.weight", "encoder.stages.1.3.norm1.weight", "encoder.stages.1.3.token_mixer.pwconv1.weight", "encoder.stages.1.3.token_mixer.act1.scale", "encoder.stages.1.3.token_mixer.act1.bias", "encoder.stages.1.3.token_mixer.dwconv.weight", "encoder.stages.1.3.token_mixer.pwconv2.weight", "encoder.stages.1.3.norm2.weight", "encoder.stages.1.3.mlp.fc1.weight", "encoder.stages.1.3.mlp.act.scale", "encoder.stages.1.3.mlp.act.bias", "encoder.stages.1.3.mlp.fc2.weight", "encoder.stages.1.4.norm1.weight", "encoder.stages.1.4.token_mixer.pwconv1.weight", "encoder.stages.1.4.token_mixer.act1.scale", "encoder.stages.1.4.token_mixer.act1.bias", "encoder.stages.1.4.token_mixer.dwconv.weight", "encoder.stages.1.4.token_mixer.pwconv2.weight", 
"encoder.stages.1.4.norm2.weight", "encoder.stages.1.4.mlp.fc1.weight", "encoder.stages.1.4.mlp.act.scale", "encoder.stages.1.4.mlp.act.bias", "encoder.stages.1.4.mlp.fc2.weight", "encoder.stages.1.5.norm1.weight", "encoder.stages.1.5.token_mixer.pwconv1.weight", "encoder.stages.1.5.token_mixer.act1.scale", "encoder.stages.1.5.token_mixer.act1.bias", "encoder.stages.1.5.token_mixer.dwconv.weight", "encoder.stages.1.5.token_mixer.pwconv2.weight", "encoder.stages.1.5.norm2.weight", "encoder.stages.1.5.mlp.fc1.weight", "encoder.stages.1.5.mlp.act.scale", "encoder.stages.1.5.mlp.act.bias", "encoder.stages.1.5.mlp.fc2.weight", "encoder.stages.1.6.norm1.weight", "encoder.stages.1.6.token_mixer.pwconv1.weight", "encoder.stages.1.6.token_mixer.act1.scale", "encoder.stages.1.6.token_mixer.act1.bias", "encoder.stages.1.6.token_mixer.dwconv.weight", "encoder.stages.1.6.token_mixer.pwconv2.weight", "encoder.stages.1.6.norm2.weight", "encoder.stages.1.6.mlp.fc1.weight", "encoder.stages.1.6.mlp.act.scale", "encoder.stages.1.6.mlp.act.bias", "encoder.stages.1.6.mlp.fc2.weight", "encoder.stages.1.7.norm1.weight", "encoder.stages.1.7.token_mixer.pwconv1.weight", "encoder.stages.1.7.token_mixer.act1.scale", "encoder.stages.1.7.token_mixer.act1.bias", "encoder.stages.1.7.token_mixer.dwconv.weight", "encoder.stages.1.7.token_mixer.pwconv2.weight", "encoder.stages.1.7.norm2.weight", "encoder.stages.1.7.mlp.fc1.weight", "encoder.stages.1.7.mlp.act.scale", "encoder.stages.1.7.mlp.act.bias", "encoder.stages.1.7.mlp.fc2.weight", "encoder.stages.1.8.norm1.weight", "encoder.stages.1.8.token_mixer.pwconv1.weight", "encoder.stages.1.8.token_mixer.act1.scale", "encoder.stages.1.8.token_mixer.act1.bias", "encoder.stages.1.8.token_mixer.dwconv.weight", "encoder.stages.1.8.token_mixer.pwconv2.weight", "encoder.stages.1.8.norm2.weight", "encoder.stages.1.8.mlp.fc1.weight", "encoder.stages.1.8.mlp.act.scale", "encoder.stages.1.8.mlp.act.bias", "encoder.stages.1.8.mlp.fc2.weight", 
"encoder.stages.1.9.norm1.weight", "encoder.stages.1.9.token_mixer.pwconv1.weight", "encoder.stages.1.9.token_mixer.act1.scale", "encoder.stages.1.9.token_mixer.act1.bias", "encoder.stages.1.9.token_mixer.dwconv.weight", "encoder.stages.1.9.token_mixer.pwconv2.weight", "encoder.stages.1.9.norm2.weight", "encoder.stages.1.9.mlp.fc1.weight", "encoder.stages.1.9.mlp.act.scale", "encoder.stages.1.9.mlp.act.bias", "encoder.stages.1.9.mlp.fc2.weight", "encoder.stages.1.10.norm1.weight", "encoder.stages.1.10.token_mixer.pwconv1.weight", "encoder.stages.1.10.token_mixer.act1.scale", "encoder.stages.1.10.token_mixer.act1.bias", "encoder.stages.1.10.token_mixer.dwconv.weight", "encoder.stages.1.10.token_mixer.pwconv2.weight", "encoder.stages.1.10.norm2.weight", "encoder.stages.1.10.mlp.fc1.weight", "encoder.stages.1.10.mlp.act.scale", "encoder.stages.1.10.mlp.act.bias", "encoder.stages.1.10.mlp.fc2.weight", "encoder.stages.1.11.norm1.weight", "encoder.stages.1.11.token_mixer.pwconv1.weight", "encoder.stages.1.11.token_mixer.act1.scale", "encoder.stages.1.11.token_mixer.act1.bias", "encoder.stages.1.11.token_mixer.dwconv.weight", "encoder.stages.1.11.token_mixer.pwconv2.weight", "encoder.stages.1.11.norm2.weight", "encoder.stages.1.11.mlp.fc1.weight", "encoder.stages.1.11.mlp.act.scale", "encoder.stages.1.11.mlp.act.bias", "encoder.stages.1.11.mlp.fc2.weight", "encoder.stages.2.0.norm1.weight", "encoder.stages.2.0.token_mixer.qkv.weight", "encoder.stages.2.0.token_mixer.proj.weight", "encoder.stages.2.0.res_scale1.scale", "encoder.stages.2.0.norm2.weight", "encoder.stages.2.0.mlp.fc1.weight", "encoder.stages.2.0.mlp.act.scale", "encoder.stages.2.0.mlp.act.bias", "encoder.stages.2.0.mlp.fc2.weight", "encoder.stages.2.0.res_scale2.scale", "encoder.stages.2.1.norm1.weight", "encoder.stages.2.1.token_mixer.qkv.weight", "encoder.stages.2.1.token_mixer.proj.weight", "encoder.stages.2.1.res_scale1.scale", "encoder.stages.2.1.norm2.weight", "encoder.stages.2.1.mlp.fc1.weight", 
"encoder.stages.2.1.mlp.act.scale", "encoder.stages.2.1.mlp.act.bias", "encoder.stages.2.1.mlp.fc2.weight", "encoder.stages.2.1.res_scale2.scale", "encoder.stages.2.2.norm1.weight", "encoder.stages.2.2.token_mixer.qkv.weight", "encoder.stages.2.2.token_mixer.proj.weight", "encoder.stages.2.2.res_scale1.scale", "encoder.stages.2.2.norm2.weight", "encoder.stages.2.2.mlp.fc1.weight", "encoder.stages.2.2.mlp.act.scale", "encoder.stages.2.2.mlp.act.bias", "encoder.stages.2.2.mlp.fc2.weight", "encoder.stages.2.2.res_scale2.scale", "encoder.stages.2.3.norm1.weight", "encoder.stages.2.3.token_mixer.qkv.weight", "encoder.stages.2.3.token_mixer.proj.weight", "encoder.stages.2.3.res_scale1.scale", "encoder.stages.2.3.norm2.weight", "encoder.stages.2.3.mlp.fc1.weight", "encoder.stages.2.3.mlp.act.scale", "encoder.stages.2.3.mlp.act.bias", "encoder.stages.2.3.mlp.fc2.weight", "encoder.stages.2.3.res_scale2.scale", "encoder.stages.2.4.norm1.weight", "encoder.stages.2.4.token_mixer.qkv.weight", "encoder.stages.2.4.token_mixer.proj.weight", "encoder.stages.2.4.res_scale1.scale", "encoder.stages.2.4.norm2.weight", "encoder.stages.2.4.mlp.fc1.weight", "encoder.stages.2.4.mlp.act.scale", "encoder.stages.2.4.mlp.act.bias", "encoder.stages.2.4.mlp.fc2.weight", "encoder.stages.2.4.res_scale2.scale", "encoder.stages.2.5.norm1.weight", "encoder.stages.2.5.token_mixer.qkv.weight", "encoder.stages.2.5.token_mixer.proj.weight", "encoder.stages.2.5.res_scale1.scale", "encoder.stages.2.5.norm2.weight", "encoder.stages.2.5.mlp.fc1.weight", "encoder.stages.2.5.mlp.act.scale", "encoder.stages.2.5.mlp.act.bias", "encoder.stages.2.5.mlp.fc2.weight", "encoder.stages.2.5.res_scale2.scale", "encoder.stages.2.6.norm1.weight", "encoder.stages.2.6.token_mixer.qkv.weight", "encoder.stages.2.6.token_mixer.proj.weight", "encoder.stages.2.6.res_scale1.scale", "encoder.stages.2.6.norm2.weight", "encoder.stages.2.6.mlp.fc1.weight", "encoder.stages.2.6.mlp.act.scale", "encoder.stages.2.6.mlp.act.bias", 
"encoder.stages.2.6.mlp.fc2.weight", "encoder.stages.2.6.res_scale2.scale", "encoder.stages.2.7.norm1.weight", "encoder.stages.2.7.token_mixer.qkv.weight", "encoder.stages.2.7.token_mixer.proj.weight", "encoder.stages.2.7.res_scale1.scale", "encoder.stages.2.7.norm2.weight", "encoder.stages.2.7.mlp.fc1.weight", "encoder.stages.2.7.mlp.act.scale", "encoder.stages.2.7.mlp.act.bias", "encoder.stages.2.7.mlp.fc2.weight", "encoder.stages.2.7.res_scale2.scale", "encoder.stages.2.8.norm1.weight", "encoder.stages.2.8.token_mixer.qkv.weight", "encoder.stages.2.8.token_mixer.proj.weight", "encoder.stages.2.8.res_scale1.scale", "encoder.stages.2.8.norm2.weight", "encoder.stages.2.8.mlp.fc1.weight", "encoder.stages.2.8.mlp.act.scale", "encoder.stages.2.8.mlp.act.bias", "encoder.stages.2.8.mlp.fc2.weight", "encoder.stages.2.8.res_scale2.scale", "encoder.stages.2.9.norm1.weight", "encoder.stages.2.9.token_mixer.qkv.weight", "encoder.stages.2.9.token_mixer.proj.weight", "encoder.stages.2.9.res_scale1.scale", "encoder.stages.2.9.norm2.weight", "encoder.stages.2.9.mlp.fc1.weight", "encoder.stages.2.9.mlp.act.scale", "encoder.stages.2.9.mlp.act.bias", "encoder.stages.2.9.mlp.fc2.weight", "encoder.stages.2.9.res_scale2.scale", "encoder.stages.2.10.norm1.weight", "encoder.stages.2.10.token_mixer.qkv.weight", "encoder.stages.2.10.token_mixer.proj.weight", "encoder.stages.2.10.res_scale1.scale", "encoder.stages.2.10.norm2.weight", "encoder.stages.2.10.mlp.fc1.weight", "encoder.stages.2.10.mlp.act.scale", "encoder.stages.2.10.mlp.act.bias", "encoder.stages.2.10.mlp.fc2.weight", "encoder.stages.2.10.res_scale2.scale", "encoder.stages.2.11.norm1.weight", "encoder.stages.2.11.token_mixer.qkv.weight", "encoder.stages.2.11.token_mixer.proj.weight", "encoder.stages.2.11.res_scale1.scale", "encoder.stages.2.11.norm2.weight", "encoder.stages.2.11.mlp.fc1.weight", "encoder.stages.2.11.mlp.act.scale", "encoder.stages.2.11.mlp.act.bias", "encoder.stages.2.11.mlp.fc2.weight", 
"encoder.stages.2.11.res_scale2.scale", "encoder.stages.2.12.norm1.weight", "encoder.stages.2.12.token_mixer.qkv.weight", "encoder.stages.2.12.token_mixer.proj.weight", "encoder.stages.2.12.res_scale1.scale", "encoder.stages.2.12.norm2.weight", "encoder.stages.2.12.mlp.fc1.weight", "encoder.stages.2.12.mlp.act.scale", "encoder.stages.2.12.mlp.act.bias", "encoder.stages.2.12.mlp.fc2.weight", "encoder.stages.2.12.res_scale2.scale", "encoder.stages.2.13.norm1.weight", "encoder.stages.2.13.token_mixer.qkv.weight", "encoder.stages.2.13.token_mixer.proj.weight", "encoder.stages.2.13.res_scale1.scale", "encoder.stages.2.13.norm2.weight", "encoder.stages.2.13.mlp.fc1.weight", "encoder.stages.2.13.mlp.act.scale", "encoder.stages.2.13.mlp.act.bias", "encoder.stages.2.13.mlp.fc2.weight", "encoder.stages.2.13.res_scale2.scale", "encoder.stages.2.14.norm1.weight", "encoder.stages.2.14.token_mixer.qkv.weight", "encoder.stages.2.14.token_mixer.proj.weight", "encoder.stages.2.14.res_scale1.scale", "encoder.stages.2.14.norm2.weight", "encoder.stages.2.14.mlp.fc1.weight", "encoder.stages.2.14.mlp.act.scale", "encoder.stages.2.14.mlp.act.bias", "encoder.stages.2.14.mlp.fc2.weight", "encoder.stages.2.14.res_scale2.scale", "encoder.stages.2.15.norm1.weight", "encoder.stages.2.15.token_mixer.qkv.weight", "encoder.stages.2.15.token_mixer.proj.weight", "encoder.stages.2.15.res_scale1.scale", "encoder.stages.2.15.norm2.weight", "encoder.stages.2.15.mlp.fc1.weight", "encoder.stages.2.15.mlp.act.scale", "encoder.stages.2.15.mlp.act.bias", "encoder.stages.2.15.mlp.fc2.weight", "encoder.stages.2.15.res_scale2.scale", "encoder.stages.2.16.norm1.weight", "encoder.stages.2.16.token_mixer.qkv.weight", "encoder.stages.2.16.token_mixer.proj.weight", "encoder.stages.2.16.res_scale1.scale", "encoder.stages.2.16.norm2.weight", "encoder.stages.2.16.mlp.fc1.weight", "encoder.stages.2.16.mlp.act.scale", "encoder.stages.2.16.mlp.act.bias", "encoder.stages.2.16.mlp.fc2.weight", 
"encoder.stages.2.16.res_scale2.scale", "encoder.stages.2.17.norm1.weight", "encoder.stages.2.17.token_mixer.qkv.weight", "encoder.stages.2.17.token_mixer.proj.weight", "encoder.stages.2.17.res_scale1.scale", "encoder.stages.2.17.norm2.weight", "encoder.stages.2.17.mlp.fc1.weight", "encoder.stages.2.17.mlp.act.scale", "encoder.stages.2.17.mlp.act.bias", "encoder.stages.2.17.mlp.fc2.weight", "encoder.stages.2.17.res_scale2.scale", "decoder.convs.conv0_0.conv1.0.weight", "decoder.convs.conv0_0.conv2.0.weight", "decoder.convs.conv1_0.conv1.0.weight", "decoder.convs.conv1_0.conv2.0.weight", "decoder.convs.conv2_0.conv1.0.weight", "decoder.convs.conv2_0.conv2.0.weight", "decoder.convs.conv3_0.conv1.0.weight", "decoder.convs.conv3_0.conv2.0.weight", "decoder.convs.conv0_1.conv1.0.weight", "decoder.convs.conv0_1.conv2.0.weight", "decoder.convs.conv1_1.conv1.0.weight", "decoder.convs.conv1_1.conv2.0.weight", "decoder.convs.conv2_1.conv1.0.weight", "decoder.convs.conv2_1.conv2.0.weight", "decoder.convs.conv0_2.conv1.0.weight", "decoder.convs.conv0_2.conv2.0.weight", "decoder.convs.conv1_2.conv1.0.weight", "decoder.convs.conv1_2.conv2.0.weight", "decoder.convs.conv0_3.conv1.0.weight", "decoder.convs.conv0_3.conv2.0.weight", "head.final.0.weight", "head.final.0.bias".
Unexpected key(s) in state_dict: "downsample_layers.0.conv.weight", "downsample_layers.0.conv.bias", "downsample_layers.0.post_norm.weight", "downsample_layers.1.pre_norm.weight", "downsample_layers.1.conv.weight", "downsample_layers.1.conv.bias", "downsample_layers.2.pre_norm.weight", "downsample_layers.2.conv.weight", "downsample_layers.2.conv.bias", "downsample_layers.3.pre_norm.weight", "downsample_layers.3.conv.weight", "downsample_layers.3.conv.bias", "stages.0.0.norm1.weight", "stages.0.0.token_mixer.pwconv1.weight", "stages.0.0.token_mixer.act1.scale", "stages.0.0.token_mixer.act1.bias", "stages.0.0.token_mixer.dwconv.weight", "stages.0.0.token_mixer.pwconv2.weight", "stages.0.0.norm2.weight", "stages.0.0.mlp.fc1.weight", "stages.0.0.mlp.act.scale", "stages.0.0.mlp.act.bias", "stages.0.0.mlp.fc2.weight", "stages.0.1.norm1.weight", "stages.0.1.token_mixer.pwconv1.weight", "stages.0.1.token_mixer.act1.scale", "stages.0.1.token_mixer.act1.bias", "stages.0.1.token_mixer.dwconv.weight", "stages.0.1.token_mixer.pwconv2.weight", "stages.0.1.norm2.weight", "stages.0.1.mlp.fc1.weight", "stages.0.1.mlp.act.scale", "stages.0.1.mlp.act.bias", "stages.0.1.mlp.fc2.weight", "stages.0.2.norm1.weight", "stages.0.2.token_mixer.pwconv1.weight", "stages.0.2.token_mixer.act1.scale", "stages.0.2.token_mixer.act1.bias", "stages.0.2.token_mixer.dwconv.weight", "stages.0.2.token_mixer.pwconv2.weight", "stages.0.2.norm2.weight", "stages.0.2.mlp.fc1.weight", "stages.0.2.mlp.act.scale", "stages.0.2.mlp.act.bias", "stages.0.2.mlp.fc2.weight", "stages.1.0.norm1.weight", "stages.1.0.token_mixer.pwconv1.weight", "stages.1.0.token_mixer.act1.scale", "stages.1.0.token_mixer.act1.bias", "stages.1.0.token_mixer.dwconv.weight", "stages.1.0.token_mixer.pwconv2.weight", "stages.1.0.norm2.weight", "stages.1.0.mlp.fc1.weight", "stages.1.0.mlp.act.scale", "stages.1.0.mlp.act.bias", "stages.1.0.mlp.fc2.weight", "stages.1.1.norm1.weight", "stages.1.1.token_mixer.pwconv1.weight", 
"stages.1.1.token_mixer.act1.scale", "stages.1.1.token_mixer.act1.bias", "stages.1.1.token_mixer.dwconv.weight", "stages.1.1.token_mixer.pwconv2.weight", "stages.1.1.norm2.weight", "stages.1.1.mlp.fc1.weight", "stages.1.1.mlp.act.scale", "stages.1.1.mlp.act.bias", "stages.1.1.mlp.fc2.weight", "stages.1.2.norm1.weight", "stages.1.2.token_mixer.pwconv1.weight", "stages.1.2.token_mixer.act1.scale", "stages.1.2.token_mixer.act1.bias", "stages.1.2.token_mixer.dwconv.weight", "stages.1.2.token_mixer.pwconv2.weight", "stages.1.2.norm2.weight", "stages.1.2.mlp.fc1.weight", "stages.1.2.mlp.act.scale", "stages.1.2.mlp.act.bias", "stages.1.2.mlp.fc2.weight", "stages.1.3.norm1.weight", "stages.1.3.token_mixer.pwconv1.weight", "stages.1.3.token_mixer.act1.scale", "stages.1.3.token_mixer.act1.bias", "stages.1.3.token_mixer.dwconv.weight", "stages.1.3.token_mixer.pwconv2.weight", "stages.1.3.norm2.weight", "stages.1.3.mlp.fc1.weight", "stages.1.3.mlp.act.scale", "stages.1.3.mlp.act.bias", "stages.1.3.mlp.fc2.weight", "stages.1.4.norm1.weight", "stages.1.4.token_mixer.pwconv1.weight", "stages.1.4.token_mixer.act1.scale", "stages.1.4.token_mixer.act1.bias", "stages.1.4.token_mixer.dwconv.weight", "stages.1.4.token_mixer.pwconv2.weight", "stages.1.4.norm2.weight", "stages.1.4.mlp.fc1.weight", "stages.1.4.mlp.act.scale", "stages.1.4.mlp.act.bias", "stages.1.4.mlp.fc2.weight", "stages.1.5.norm1.weight", "stages.1.5.token_mixer.pwconv1.weight", "stages.1.5.token_mixer.act1.scale", "stages.1.5.token_mixer.act1.bias", "stages.1.5.token_mixer.dwconv.weight", "stages.1.5.token_mixer.pwconv2.weight", "stages.1.5.norm2.weight", "stages.1.5.mlp.fc1.weight", "stages.1.5.mlp.act.scale", "stages.1.5.mlp.act.bias", "stages.1.5.mlp.fc2.weight", "stages.1.6.norm1.weight", "stages.1.6.token_mixer.pwconv1.weight", "stages.1.6.token_mixer.act1.scale", "stages.1.6.token_mixer.act1.bias", "stages.1.6.token_mixer.dwconv.weight", "stages.1.6.token_mixer.pwconv2.weight", "stages.1.6.norm2.weight", 
"stages.1.6.mlp.fc1.weight", "stages.1.6.mlp.act.scale", "stages.1.6.mlp.act.bias", "stages.1.6.mlp.fc2.weight", "stages.1.7.norm1.weight", "stages.1.7.token_mixer.pwconv1.weight", "stages.1.7.token_mixer.act1.scale", "stages.1.7.token_mixer.act1.bias", "stages.1.7.token_mixer.dwconv.weight", "stages.1.7.token_mixer.pwconv2.weight", "stages.1.7.norm2.weight", "stages.1.7.mlp.fc1.weight", "stages.1.7.mlp.act.scale", "stages.1.7.mlp.act.bias", "stages.1.7.mlp.fc2.weight", "stages.1.8.norm1.weight", "stages.1.8.token_mixer.pwconv1.weight", "stages.1.8.token_mixer.act1.scale", "stages.1.8.token_mixer.act1.bias", "stages.1.8.token_mixer.dwconv.weight", "stages.1.8.token_mixer.pwconv2.weight", "stages.1.8.norm2.weight", "stages.1.8.mlp.fc1.weight", "stages.1.8.mlp.act.scale", "stages.1.8.mlp.act.bias", "stages.1.8.mlp.fc2.weight", "stages.1.9.norm1.weight", "stages.1.9.token_mixer.pwconv1.weight", "stages.1.9.token_mixer.act1.scale", "stages.1.9.token_mixer.act1.bias", "stages.1.9.token_mixer.dwconv.weight", "stages.1.9.token_mixer.pwconv2.weight", "stages.1.9.norm2.weight", "stages.1.9.mlp.fc1.weight", "stages.1.9.mlp.act.scale", "stages.1.9.mlp.act.bias", "stages.1.9.mlp.fc2.weight", "stages.1.10.norm1.weight", "stages.1.10.token_mixer.pwconv1.weight", "stages.1.10.token_mixer.act1.scale", "stages.1.10.token_mixer.act1.bias", "stages.1.10.token_mixer.dwconv.weight", "stages.1.10.token_mixer.pwconv2.weight", "stages.1.10.norm2.weight", "stages.1.10.mlp.fc1.weight", "stages.1.10.mlp.act.scale", "stages.1.10.mlp.act.bias", "stages.1.10.mlp.fc2.weight", "stages.1.11.norm1.weight", "stages.1.11.token_mixer.pwconv1.weight", "stages.1.11.token_mixer.act1.scale", "stages.1.11.token_mixer.act1.bias", "stages.1.11.token_mixer.dwconv.weight", "stages.1.11.token_mixer.pwconv2.weight", "stages.1.11.norm2.weight", "stages.1.11.mlp.fc1.weight", "stages.1.11.mlp.act.scale", "stages.1.11.mlp.act.bias", "stages.1.11.mlp.fc2.weight", "stages.2.0.norm1.weight", 
"stages.2.0.token_mixer.qkv.weight", "stages.2.0.token_mixer.proj.weight", "stages.2.0.res_scale1.scale", "stages.2.0.norm2.weight", "stages.2.0.mlp.fc1.weight", "stages.2.0.mlp.act.scale", "stages.2.0.mlp.act.bias", "stages.2.0.mlp.fc2.weight", "stages.2.0.res_scale2.scale", "stages.2.1.norm1.weight", "stages.2.1.token_mixer.qkv.weight", "stages.2.1.token_mixer.proj.weight", "stages.2.1.res_scale1.scale", "stages.2.1.norm2.weight", "stages.2.1.mlp.fc1.weight", "stages.2.1.mlp.act.scale", "stages.2.1.mlp.act.bias", "stages.2.1.mlp.fc2.weight", "stages.2.1.res_scale2.scale", "stages.2.2.norm1.weight", "stages.2.2.token_mixer.qkv.weight", "stages.2.2.token_mixer.proj.weight", "stages.2.2.res_scale1.scale", "stages.2.2.norm2.weight", "stages.2.2.mlp.fc1.weight", "stages.2.2.mlp.act.scale", "stages.2.2.mlp.act.bias", "stages.2.2.mlp.fc2.weight", "stages.2.2.res_scale2.scale", "stages.2.3.norm1.weight", "stages.2.3.token_mixer.qkv.weight", "stages.2.3.token_mixer.proj.weight", "stages.2.3.res_scale1.scale", "stages.2.3.norm2.weight", "stages.2.3.mlp.fc1.weight", "stages.2.3.mlp.act.scale", "stages.2.3.mlp.act.bias", "stages.2.3.mlp.fc2.weight", "stages.2.3.res_scale2.scale", "stages.2.4.norm1.weight", "stages.2.4.token_mixer.qkv.weight", "stages.2.4.token_mixer.proj.weight", "stages.2.4.res_scale1.scale", "stages.2.4.norm2.weight", "stages.2.4.mlp.fc1.weight", "stages.2.4.mlp.act.scale", "stages.2.4.mlp.act.bias", "stages.2.4.mlp.fc2.weight", "stages.2.4.res_scale2.scale", "stages.2.5.norm1.weight", "stages.2.5.token_mixer.qkv.weight", "stages.2.5.token_mixer.proj.weight", "stages.2.5.res_scale1.scale", "stages.2.5.norm2.weight", "stages.2.5.mlp.fc1.weight", "stages.2.5.mlp.act.scale", "stages.2.5.mlp.act.bias", "stages.2.5.mlp.fc2.weight", "stages.2.5.res_scale2.scale", "stages.2.6.norm1.weight", "stages.2.6.token_mixer.qkv.weight", "stages.2.6.token_mixer.proj.weight", "stages.2.6.res_scale1.scale", "stages.2.6.norm2.weight", "stages.2.6.mlp.fc1.weight", 
"stages.2.6.mlp.act.scale", "stages.2.6.mlp.act.bias", "stages.2.6.mlp.fc2.weight", "stages.2.6.res_scale2.scale", "stages.2.7.norm1.weight", "stages.2.7.token_mixer.qkv.weight", "stages.2.7.token_mixer.proj.weight", "stages.2.7.res_scale1.scale", "stages.2.7.norm2.weight", "stages.2.7.mlp.fc1.weight", "stages.2.7.mlp.act.scale", "stages.2.7.mlp.act.bias", "stages.2.7.mlp.fc2.weight", "stages.2.7.res_scale2.scale", "stages.2.8.norm1.weight", "stages.2.8.token_mixer.qkv.weight", "stages.2.8.token_mixer.proj.weight", "stages.2.8.res_scale1.scale", "stages.2.8.norm2.weight", "stages.2.8.mlp.fc1.weight", "stages.2.8.mlp.act.scale", "stages.2.8.mlp.act.bias", "stages.2.8.mlp.fc2.weight", "stages.2.8.res_scale2.scale", "stages.2.9.norm1.weight", "stages.2.9.token_mixer.qkv.weight", "stages.2.9.token_mixer.proj.weight", "stages.2.9.res_scale1.scale", "stages.2.9.norm2.weight", "stages.2.9.mlp.fc1.weight", "stages.2.9.mlp.act.scale", "stages.2.9.mlp.act.bias", "stages.2.9.mlp.fc2.weight", "stages.2.9.res_scale2.scale", "stages.2.10.norm1.weight", "stages.2.10.token_mixer.qkv.weight", "stages.2.10.token_mixer.proj.weight", "stages.2.10.res_scale1.scale", "stages.2.10.norm2.weight", "stages.2.10.mlp.fc1.weight", "stages.2.10.mlp.act.scale", "stages.2.10.mlp.act.bias", "stages.2.10.mlp.fc2.weight", "stages.2.10.res_scale2.scale", "stages.2.11.norm1.weight", "stages.2.11.token_mixer.qkv.weight", "stages.2.11.token_mixer.proj.weight", "stages.2.11.res_scale1.scale", "stages.2.11.norm2.weight", "stages.2.11.mlp.fc1.weight", "stages.2.11.mlp.act.scale", "stages.2.11.mlp.act.bias", "stages.2.11.mlp.fc2.weight", "stages.2.11.res_scale2.scale", "stages.2.12.norm1.weight", "stages.2.12.token_mixer.qkv.weight", "stages.2.12.token_mixer.proj.weight", "stages.2.12.res_scale1.scale", "stages.2.12.norm2.weight", "stages.2.12.mlp.fc1.weight", "stages.2.12.mlp.act.scale", "stages.2.12.mlp.act.bias", "stages.2.12.mlp.fc2.weight", "stages.2.12.res_scale2.scale", "stages.2.13.norm1.weight", 
"stages.2.13.token_mixer.qkv.weight", "stages.2.13.token_mixer.proj.weight", "stages.2.13.res_scale1.scale", "stages.2.13.norm2.weight", "stages.2.13.mlp.fc1.weight", "stages.2.13.mlp.act.scale", "stages.2.13.mlp.act.bias", "stages.2.13.mlp.fc2.weight", "stages.2.13.res_scale2.scale", "stages.2.14.norm1.weight", "stages.2.14.token_mixer.qkv.weight", "stages.2.14.token_mixer.proj.weight", "stages.2.14.res_scale1.scale", "stages.2.14.norm2.weight", "stages.2.14.mlp.fc1.weight", "stages.2.14.mlp.act.scale", "stages.2.14.mlp.act.bias", "stages.2.14.mlp.fc2.weight", "stages.2.14.res_scale2.scale", "stages.2.15.norm1.weight", "stages.2.15.token_mixer.qkv.weight", "stages.2.15.token_mixer.proj.weight", "stages.2.15.res_scale1.scale", "stages.2.15.norm2.weight", "stages.2.15.mlp.fc1.weight", "stages.2.15.mlp.act.scale", "stages.2.15.mlp.act.bias", "stages.2.15.mlp.fc2.weight", "stages.2.15.res_scale2.scale", "stages.2.16.norm1.weight", "stages.2.16.token_mixer.qkv.weight", "stages.2.16.token_mixer.proj.weight", "stages.2.16.res_scale1.scale", "stages.2.16.norm2.weight", "stages.2.16.mlp.fc1.weight", "stages.2.16.mlp.act.scale", "stages.2.16.mlp.act.bias", "stages.2.16.mlp.fc2.weight", "stages.2.16.res_scale2.scale", "stages.2.17.norm1.weight", "stages.2.17.token_mixer.qkv.weight", "stages.2.17.token_mixer.proj.weight", "stages.2.17.res_scale1.scale", "stages.2.17.norm2.weight", "stages.2.17.mlp.fc1.weight", "stages.2.17.mlp.act.scale", "stages.2.17.mlp.act.bias", "stages.2.17.mlp.fc2.weight", "stages.2.17.res_scale2.scale", "stages.3.0.norm1.weight", "stages.3.0.token_mixer.qkv.weight", "stages.3.0.token_mixer.proj.weight", "stages.3.0.res_scale1.scale", "stages.3.0.norm2.weight", "stages.3.0.mlp.fc1.weight", "stages.3.0.mlp.act.scale", "stages.3.0.mlp.act.bias", "stages.3.0.mlp.fc2.weight", "stages.3.0.res_scale2.scale", "stages.3.1.norm1.weight", "stages.3.1.token_mixer.qkv.weight", "stages.3.1.token_mixer.proj.weight", "stages.3.1.res_scale1.scale", 
"stages.3.1.norm2.weight", "stages.3.1.mlp.fc1.weight", "stages.3.1.mlp.act.scale", "stages.3.1.mlp.act.bias", "stages.3.1.mlp.fc2.weight", "stages.3.1.res_scale2.scale", "stages.3.2.norm1.weight", "stages.3.2.token_mixer.qkv.weight", "stages.3.2.token_mixer.proj.weight", "stages.3.2.res_scale1.scale", "stages.3.2.norm2.weight", "stages.3.2.mlp.fc1.weight", "stages.3.2.mlp.act.scale", "stages.3.2.mlp.act.bias", "stages.3.2.mlp.fc2.weight", "stages.3.2.res_scale2.scale", "norm.weight", "norm.bias", "head.fc1.weight", "head.fc1.bias", "head.norm.weight", "head.norm.bias", "head.fc2.weight", "head.fc2.bias".
This error is raised by the call model.load_state_dict(ckpt).
Before running this script I downloaded the checkpoint you mentioned in README: https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth
I have seen that in the main.py script you load the model from a checkpoint like this:
# Excerpt from main.py: restore a checkpoint into the model when --resume is given.
if args.resume is not None:
# Load the file onto the CPU and take the weights stored under its 'state_dict' entry.
ckpt = torch.load(args.resume, map_location='cpu')['state_dict']
# Move the entries stored as encoder.conv2.1.* to the encoder.conv2.0.* keys before loading.
ckpt["encoder.conv2.0.weight"] = ckpt.pop("encoder.conv2.1.weight")
ckpt["encoder.conv2.0.bias"] = ckpt.pop("encoder.conv2.1.bias")
model.load_state_dict(ckpt)
print("load pretrained model, successfully!")
However, I cannot load the model this way, as I'm getting an error saying that there is no key 'state_dict' in the ckpt object. Also, I cannot pop "encoder.conv2.1.weight" and "encoder.conv2.1.bias", since there are no such keys in the ckpt object.
Here are my python libs versions inside virtual environment:
Ubuntu 22.04
Python 3.10.12
Package Version
------------------------ ----------
certifi 2024.8.30
charset-normalizer 3.4.0
filelock 3.16.1
fsspec 2024.10.0
huggingface-hub 0.26.0
idna 3.10
Jinja2 3.1.4
MarkupSafe 3.0.2
mpmath 1.3.0
networkx 3.4.1
numpy 2.1.2
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
packaging 24.1
pillow 11.0.0
pip 22.0.2
PyYAML 6.0.2
requests 2.32.3
safetensors 0.4.5
setuptools 59.6.0
sympy 1.13.1
timm 1.0.11
torch 2.5.0
torchvision 0.20.0
tqdm 4.66.5
triton 3.1.0
typing_extensions 4.12.2
urllib3 2.2.3
Could you please provide us with a separate inference script that we can easily run on an arbitrary input image using the pretrained weights you mentioned?
There is no problem with the code: other people have successfully reproduced the results using the code I provided. I am currently on a week-long trip and do not have my computer with me. Please try to solve the problem yourself first; if it is still not solved, I can get back to you next week.
---Original--- From: "Nikola @.> Date: Mon, Oct 21, 2024 11:41 AM To: @.>; Cc: @.**@.>; Subject: Re: [Li-yachuan/NBED] Inference script (Issue #3)
Here is the my_inference.py script:
import torch import torchvision.transforms as transforms from PIL import Image import numpy as np import argparse import os # Import the model definition from model.basemodel import Basemodel def load_model(args, device): # Initialize the model model = Basemodel(encoder_name=args.encoder, decoder_name=args.decoder, head_name=args.head).to(device) # Load the pretrained weights if args.resume is not None: ckpt = torch.load(args.resume, weights_only=True, map_location=device) if 'state_dict' in ckpt: ckpt = ckpt['state_dict'] model.load_state_dict(ckpt) else: print("No pretrained weights provided. Using untrained model.") model.eval() return model def preprocess_image(image_path, device): # Define the image transformations transform = transforms.Compose([ transforms.Resize((512, 512)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Load and preprocess the image image = Image.open(image_path).convert('RGB') image = transform(image).unsqueeze(0) # Add batch dimension image = image.to(device) return image def run_inference(model, input_image, device): with torch.no_grad(): output = model(input_image) if isinstance(output, (tuple, list)): output = output[0] # Process the output tensor to create an image output = output.squeeze().cpu().numpy() # Normalize the output to [0, 1] output = (output - output.min()) / (output.max() - output.min()) output_image = Image.fromarray((output * 255).astype(np.uint8)) return output_image def main(): parser = argparse.ArgumentParser(description='NBED Edge Detection Inference') parser.add_argument('--input', type=str, required=True, help='Path to the input image file') parser.add_argument('--output', type=str, default='output.png', help='Path to save the output image') parser.add_argument('--resume', type=str, default=None, help='Path to the pretrained model weights (.pth file)') parser.add_argument("--encoder", default="Dul-M36", help="Options: caformer-m36, Dul-M36") 
parser.add_argument("--decoder", default="unetp", help="Options: unet, unetp, default") parser.add_argument("--head", default="default", help="Options: default, aspp, atten, cofusion") args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load the model model = load_model(args, device) # Preprocess the input image input_image = preprocess_image(args.input, device) # Run inference output_image = run_inference(model, input_image, device) # Save the output image output_image.save(args.output) print(f'Edge-detected image saved to {args.output}') if __name__ == '__main__': main() # Run command: python3 my_inference.py --input Imgs/test_assembly.png --output outputs/edge_output.png --resume model/caformer_m36_384_in21ft1k.pth
And this is the error message I get on model loading:
RuntimeError: Error(s) in loading state_dict for Basemodel: Missing key(s) in state_dict: "encoder.conv1.0.weight", "encoder.conv1.0.bias", "encoder.conv2.0.weight", "encoder.conv2.0.bias", "encoder.downsample_layers.0.conv.weight", "encoder.downsample_layers.0.conv.bias", "encoder.downsample_layers.0.post_norm.weight", "encoder.downsample_layers.1.pre_norm.weight", "encoder.downsample_layers.1.conv.weight", "encoder.downsample_layers.1.conv.bias", "encoder.downsample_layers.2.pre_norm.weight", "encoder.downsample_layers.2.conv.weight", "encoder.downsample_layers.2.conv.bias", "encoder.stages.0.0.norm1.weight", "encoder.stages.0.0.token_mixer.pwconv1.weight", "encoder.stages.0.0.token_mixer.act1.scale", "encoder.stages.0.0.token_mixer.act1.bias", "encoder.stages.0.0.token_mixer.dwconv.weight", "encoder.stages.0.0.token_mixer.pwconv2.weight", "encoder.stages.0.0.norm2.weight", "encoder.stages.0.0.mlp.fc1.weight", "encoder.stages.0.0.mlp.act.scale", "encoder.stages.0.0.mlp.act.bias", "encoder.stages.0.0.mlp.fc2.weight", "encoder.stages.0.1.norm1.weight", "encoder.stages.0.1.token_mixer.pwconv1.weight", "encoder.stages.0.1.token_mixer.act1.scale", "encoder.stages.0.1.token_mixer.act1.bias", "encoder.stages.0.1.token_mixer.dwconv.weight", "encoder.stages.0.1.token_mixer.pwconv2.weight", "encoder.stages.0.1.norm2.weight", "encoder.stages.0.1.mlp.fc1.weight", "encoder.stages.0.1.mlp.act.scale", "encoder.stages.0.1.mlp.act.bias", "encoder.stages.0.1.mlp.fc2.weight", "encoder.stages.0.2.norm1.weight", "encoder.stages.0.2.token_mixer.pwconv1.weight", "encoder.stages.0.2.token_mixer.act1.scale", "encoder.stages.0.2.token_mixer.act1.bias", "encoder.stages.0.2.token_mixer.dwconv.weight", "encoder.stages.0.2.token_mixer.pwconv2.weight", "encoder.stages.0.2.norm2.weight", "encoder.stages.0.2.mlp.fc1.weight", "encoder.stages.0.2.mlp.act.scale", "encoder.stages.0.2.mlp.act.bias", "encoder.stages.0.2.mlp.fc2.weight", "encoder.stages.1.0.norm1.weight", 
"encoder.stages.1.0.token_mixer.pwconv1.weight", "encoder.stages.1.0.token_mixer.act1.scale", "encoder.stages.1.0.token_mixer.act1.bias", "encoder.stages.1.0.token_mixer.dwconv.weight", "encoder.stages.1.0.token_mixer.pwconv2.weight", "encoder.stages.1.0.norm2.weight", "encoder.stages.1.0.mlp.fc1.weight", "encoder.stages.1.0.mlp.act.scale", "encoder.stages.1.0.mlp.act.bias", "encoder.stages.1.0.mlp.fc2.weight", "encoder.stages.1.1.norm1.weight", "encoder.stages.1.1.token_mixer.pwconv1.weight", "encoder.stages.1.1.token_mixer.act1.scale", "encoder.stages.1.1.token_mixer.act1.bias", "encoder.stages.1.1.token_mixer.dwconv.weight", "encoder.stages.1.1.token_mixer.pwconv2.weight", "encoder.stages.1.1.norm2.weight", "encoder.stages.1.1.mlp.fc1.weight", "encoder.stages.1.1.mlp.act.scale", "encoder.stages.1.1.mlp.act.bias", "encoder.stages.1.1.mlp.fc2.weight", "encoder.stages.1.2.norm1.weight", "encoder.stages.1.2.token_mixer.pwconv1.weight", "encoder.stages.1.2.token_mixer.act1.scale", "encoder.stages.1.2.token_mixer.act1.bias", "encoder.stages.1.2.token_mixer.dwconv.weight", "encoder.stages.1.2.token_mixer.pwconv2.weight", "encoder.stages.1.2.norm2.weight", "encoder.stages.1.2.mlp.fc1.weight", "encoder.stages.1.2.mlp.act.scale", "encoder.stages.1.2.mlp.act.bias", "encoder.stages.1.2.mlp.fc2.weight", "encoder.stages.1.3.norm1.weight", "encoder.stages.1.3.token_mixer.pwconv1.weight", "encoder.stages.1.3.token_mixer.act1.scale", "encoder.stages.1.3.token_mixer.act1.bias", "encoder.stages.1.3.token_mixer.dwconv.weight", "encoder.stages.1.3.token_mixer.pwconv2.weight", "encoder.stages.1.3.norm2.weight", "encoder.stages.1.3.mlp.fc1.weight", "encoder.stages.1.3.mlp.act.scale", "encoder.stages.1.3.mlp.act.bias", "encoder.stages.1.3.mlp.fc2.weight", "encoder.stages.1.4.norm1.weight", "encoder.stages.1.4.token_mixer.pwconv1.weight", "encoder.stages.1.4.token_mixer.act1.scale", "encoder.stages.1.4.token_mixer.act1.bias", "encoder.stages.1.4.token_mixer.dwconv.weight", 
"encoder.stages.1.4.token_mixer.pwconv2.weight", "encoder.stages.1.4.norm2.weight", "encoder.stages.1.4.mlp.fc1.weight", "encoder.stages.1.4.mlp.act.scale", "encoder.stages.1.4.mlp.act.bias", "encoder.stages.1.4.mlp.fc2.weight", "encoder.stages.1.5.norm1.weight", "encoder.stages.1.5.token_mixer.pwconv1.weight", "encoder.stages.1.5.token_mixer.act1.scale", "encoder.stages.1.5.token_mixer.act1.bias", "encoder.stages.1.5.token_mixer.dwconv.weight", "encoder.stages.1.5.token_mixer.pwconv2.weight", "encoder.stages.1.5.norm2.weight", "encoder.stages.1.5.mlp.fc1.weight", "encoder.stages.1.5.mlp.act.scale", "encoder.stages.1.5.mlp.act.bias", "encoder.stages.1.5.mlp.fc2.weight", "encoder.stages.1.6.norm1.weight", "encoder.stages.1.6.token_mixer.pwconv1.weight", "encoder.stages.1.6.token_mixer.act1.scale", "encoder.stages.1.6.token_mixer.act1.bias", "encoder.stages.1.6.token_mixer.dwconv.weight", "encoder.stages.1.6.token_mixer.pwconv2.weight", "encoder.stages.1.6.norm2.weight", "encoder.stages.1.6.mlp.fc1.weight", "encoder.stages.1.6.mlp.act.scale", "encoder.stages.1.6.mlp.act.bias", "encoder.stages.1.6.mlp.fc2.weight", "encoder.stages.1.7.norm1.weight", "encoder.stages.1.7.token_mixer.pwconv1.weight", "encoder.stages.1.7.token_mixer.act1.scale", "encoder.stages.1.7.token_mixer.act1.bias", "encoder.stages.1.7.token_mixer.dwconv.weight", "encoder.stages.1.7.token_mixer.pwconv2.weight", "encoder.stages.1.7.norm2.weight", "encoder.stages.1.7.mlp.fc1.weight", "encoder.stages.1.7.mlp.act.scale", "encoder.stages.1.7.mlp.act.bias", "encoder.stages.1.7.mlp.fc2.weight", "encoder.stages.1.8.norm1.weight", "encoder.stages.1.8.token_mixer.pwconv1.weight", "encoder.stages.1.8.token_mixer.act1.scale", "encoder.stages.1.8.token_mixer.act1.bias", "encoder.stages.1.8.token_mixer.dwconv.weight", "encoder.stages.1.8.token_mixer.pwconv2.weight", "encoder.stages.1.8.norm2.weight", "encoder.stages.1.8.mlp.fc1.weight", "encoder.stages.1.8.mlp.act.scale", "encoder.stages.1.8.mlp.act.bias", 
"encoder.stages.1.8.mlp.fc2.weight", "encoder.stages.1.9.norm1.weight", "encoder.stages.1.9.token_mixer.pwconv1.weight", "encoder.stages.1.9.token_mixer.act1.scale", "encoder.stages.1.9.token_mixer.act1.bias", "encoder.stages.1.9.token_mixer.dwconv.weight", "encoder.stages.1.9.token_mixer.pwconv2.weight", "encoder.stages.1.9.norm2.weight", "encoder.stages.1.9.mlp.fc1.weight", "encoder.stages.1.9.mlp.act.scale", "encoder.stages.1.9.mlp.act.bias", "encoder.stages.1.9.mlp.fc2.weight", "encoder.stages.1.10.norm1.weight", "encoder.stages.1.10.token_mixer.pwconv1.weight", "encoder.stages.1.10.token_mixer.act1.scale", "encoder.stages.1.10.token_mixer.act1.bias", "encoder.stages.1.10.token_mixer.dwconv.weight", "encoder.stages.1.10.token_mixer.pwconv2.weight", "encoder.stages.1.10.norm2.weight", "encoder.stages.1.10.mlp.fc1.weight", "encoder.stages.1.10.mlp.act.scale", "encoder.stages.1.10.mlp.act.bias", "encoder.stages.1.10.mlp.fc2.weight", "encoder.stages.1.11.norm1.weight", "encoder.stages.1.11.token_mixer.pwconv1.weight", "encoder.stages.1.11.token_mixer.act1.scale", "encoder.stages.1.11.token_mixer.act1.bias", "encoder.stages.1.11.token_mixer.dwconv.weight", "encoder.stages.1.11.token_mixer.pwconv2.weight", "encoder.stages.1.11.norm2.weight", "encoder.stages.1.11.mlp.fc1.weight", "encoder.stages.1.11.mlp.act.scale", "encoder.stages.1.11.mlp.act.bias", "encoder.stages.1.11.mlp.fc2.weight", "encoder.stages.2.0.norm1.weight", "encoder.stages.2.0.token_mixer.qkv.weight", "encoder.stages.2.0.token_mixer.proj.weight", "encoder.stages.2.0.res_scale1.scale", "encoder.stages.2.0.norm2.weight", "encoder.stages.2.0.mlp.fc1.weight", "encoder.stages.2.0.mlp.act.scale", "encoder.stages.2.0.mlp.act.bias", "encoder.stages.2.0.mlp.fc2.weight", "encoder.stages.2.0.res_scale2.scale", "encoder.stages.2.1.norm1.weight", "encoder.stages.2.1.token_mixer.qkv.weight", "encoder.stages.2.1.token_mixer.proj.weight", "encoder.stages.2.1.res_scale1.scale", "encoder.stages.2.1.norm2.weight", 
"encoder.stages.2.1.mlp.fc1.weight", "encoder.stages.2.1.mlp.act.scale", "encoder.stages.2.1.mlp.act.bias", "encoder.stages.2.1.mlp.fc2.weight", "encoder.stages.2.1.res_scale2.scale", "encoder.stages.2.2.norm1.weight", "encoder.stages.2.2.token_mixer.qkv.weight", "encoder.stages.2.2.token_mixer.proj.weight", "encoder.stages.2.2.res_scale1.scale", "encoder.stages.2.2.norm2.weight", "encoder.stages.2.2.mlp.fc1.weight", "encoder.stages.2.2.mlp.act.scale", "encoder.stages.2.2.mlp.act.bias", "encoder.stages.2.2.mlp.fc2.weight", "encoder.stages.2.2.res_scale2.scale", "encoder.stages.2.3.norm1.weight", "encoder.stages.2.3.token_mixer.qkv.weight", "encoder.stages.2.3.token_mixer.proj.weight", "encoder.stages.2.3.res_scale1.scale", "encoder.stages.2.3.norm2.weight", "encoder.stages.2.3.mlp.fc1.weight", "encoder.stages.2.3.mlp.act.scale", "encoder.stages.2.3.mlp.act.bias", "encoder.stages.2.3.mlp.fc2.weight", "encoder.stages.2.3.res_scale2.scale", "encoder.stages.2.4.norm1.weight", "encoder.stages.2.4.token_mixer.qkv.weight", "encoder.stages.2.4.token_mixer.proj.weight", "encoder.stages.2.4.res_scale1.scale", "encoder.stages.2.4.norm2.weight", "encoder.stages.2.4.mlp.fc1.weight", "encoder.stages.2.4.mlp.act.scale", "encoder.stages.2.4.mlp.act.bias", "encoder.stages.2.4.mlp.fc2.weight", "encoder.stages.2.4.res_scale2.scale", "encoder.stages.2.5.norm1.weight", "encoder.stages.2.5.token_mixer.qkv.weight", "encoder.stages.2.5.token_mixer.proj.weight", "encoder.stages.2.5.res_scale1.scale", "encoder.stages.2.5.norm2.weight", "encoder.stages.2.5.mlp.fc1.weight", "encoder.stages.2.5.mlp.act.scale", "encoder.stages.2.5.mlp.act.bias", "encoder.stages.2.5.mlp.fc2.weight", "encoder.stages.2.5.res_scale2.scale", "encoder.stages.2.6.norm1.weight", "encoder.stages.2.6.token_mixer.qkv.weight", "encoder.stages.2.6.token_mixer.proj.weight", "encoder.stages.2.6.res_scale1.scale", "encoder.stages.2.6.norm2.weight", "encoder.stages.2.6.mlp.fc1.weight", "encoder.stages.2.6.mlp.act.scale", 
"encoder.stages.2.6.mlp.act.bias", "encoder.stages.2.6.mlp.fc2.weight", "encoder.stages.2.6.res_scale2.scale", "encoder.stages.2.7.norm1.weight", "encoder.stages.2.7.token_mixer.qkv.weight", "encoder.stages.2.7.token_mixer.proj.weight", "encoder.stages.2.7.res_scale1.scale", "encoder.stages.2.7.norm2.weight", "encoder.stages.2.7.mlp.fc1.weight", "encoder.stages.2.7.mlp.act.scale", "encoder.stages.2.7.mlp.act.bias", "encoder.stages.2.7.mlp.fc2.weight", "encoder.stages.2.7.res_scale2.scale", "encoder.stages.2.8.norm1.weight", "encoder.stages.2.8.token_mixer.qkv.weight", "encoder.stages.2.8.token_mixer.proj.weight", "encoder.stages.2.8.res_scale1.scale", "encoder.stages.2.8.norm2.weight", "encoder.stages.2.8.mlp.fc1.weight", "encoder.stages.2.8.mlp.act.scale", "encoder.stages.2.8.mlp.act.bias", "encoder.stages.2.8.mlp.fc2.weight", "encoder.stages.2.8.res_scale2.scale", "encoder.stages.2.9.norm1.weight", "encoder.stages.2.9.token_mixer.qkv.weight", "encoder.stages.2.9.token_mixer.proj.weight", "encoder.stages.2.9.res_scale1.scale", "encoder.stages.2.9.norm2.weight", "encoder.stages.2.9.mlp.fc1.weight", "encoder.stages.2.9.mlp.act.scale", "encoder.stages.2.9.mlp.act.bias", "encoder.stages.2.9.mlp.fc2.weight", "encoder.stages.2.9.res_scale2.scale", "encoder.stages.2.10.norm1.weight", "encoder.stages.2.10.token_mixer.qkv.weight", "encoder.stages.2.10.token_mixer.proj.weight", "encoder.stages.2.10.res_scale1.scale", "encoder.stages.2.10.norm2.weight", "encoder.stages.2.10.mlp.fc1.weight", "encoder.stages.2.10.mlp.act.scale", "encoder.stages.2.10.mlp.act.bias", "encoder.stages.2.10.mlp.fc2.weight", "encoder.stages.2.10.res_scale2.scale", "encoder.stages.2.11.norm1.weight", "encoder.stages.2.11.token_mixer.qkv.weight", "encoder.stages.2.11.token_mixer.proj.weight", "encoder.stages.2.11.res_scale1.scale", "encoder.stages.2.11.norm2.weight", "encoder.stages.2.11.mlp.fc1.weight", "encoder.stages.2.11.mlp.act.scale", "encoder.stages.2.11.mlp.act.bias", 
"encoder.stages.2.11.mlp.fc2.weight", "encoder.stages.2.11.res_scale2.scale", "encoder.stages.2.12.norm1.weight", "encoder.stages.2.12.token_mixer.qkv.weight", "encoder.stages.2.12.token_mixer.proj.weight", "encoder.stages.2.12.res_scale1.scale", "encoder.stages.2.12.norm2.weight", "encoder.stages.2.12.mlp.fc1.weight", "encoder.stages.2.12.mlp.act.scale", "encoder.stages.2.12.mlp.act.bias", "encoder.stages.2.12.mlp.fc2.weight", "encoder.stages.2.12.res_scale2.scale", "encoder.stages.2.13.norm1.weight", "encoder.stages.2.13.token_mixer.qkv.weight", "encoder.stages.2.13.token_mixer.proj.weight", "encoder.stages.2.13.res_scale1.scale", "encoder.stages.2.13.norm2.weight", "encoder.stages.2.13.mlp.fc1.weight", "encoder.stages.2.13.mlp.act.scale", "encoder.stages.2.13.mlp.act.bias", "encoder.stages.2.13.mlp.fc2.weight", "encoder.stages.2.13.res_scale2.scale", "encoder.stages.2.14.norm1.weight", "encoder.stages.2.14.token_mixer.qkv.weight", "encoder.stages.2.14.token_mixer.proj.weight", "encoder.stages.2.14.res_scale1.scale", "encoder.stages.2.14.norm2.weight", "encoder.stages.2.14.mlp.fc1.weight", "encoder.stages.2.14.mlp.act.scale", "encoder.stages.2.14.mlp.act.bias", "encoder.stages.2.14.mlp.fc2.weight", "encoder.stages.2.14.res_scale2.scale", "encoder.stages.2.15.norm1.weight", "encoder.stages.2.15.token_mixer.qkv.weight", "encoder.stages.2.15.token_mixer.proj.weight", "encoder.stages.2.15.res_scale1.scale", "encoder.stages.2.15.norm2.weight", "encoder.stages.2.15.mlp.fc1.weight", "encoder.stages.2.15.mlp.act.scale", "encoder.stages.2.15.mlp.act.bias", "encoder.stages.2.15.mlp.fc2.weight", "encoder.stages.2.15.res_scale2.scale", "encoder.stages.2.16.norm1.weight", "encoder.stages.2.16.token_mixer.qkv.weight", "encoder.stages.2.16.token_mixer.proj.weight", "encoder.stages.2.16.res_scale1.scale", "encoder.stages.2.16.norm2.weight", "encoder.stages.2.16.mlp.fc1.weight", "encoder.stages.2.16.mlp.act.scale", "encoder.stages.2.16.mlp.act.bias", 
"encoder.stages.2.16.mlp.fc2.weight", "encoder.stages.2.16.res_scale2.scale", "encoder.stages.2.17.norm1.weight", "encoder.stages.2.17.token_mixer.qkv.weight", "encoder.stages.2.17.token_mixer.proj.weight", "encoder.stages.2.17.res_scale1.scale", "encoder.stages.2.17.norm2.weight", "encoder.stages.2.17.mlp.fc1.weight", "encoder.stages.2.17.mlp.act.scale", "encoder.stages.2.17.mlp.act.bias", "encoder.stages.2.17.mlp.fc2.weight", "encoder.stages.2.17.res_scale2.scale", "decoder.convs.conv0_0.conv1.0.weight", "decoder.convs.conv0_0.conv2.0.weight", "decoder.convs.conv1_0.conv1.0.weight", "decoder.convs.conv1_0.conv2.0.weight", "decoder.convs.conv2_0.conv1.0.weight", "decoder.convs.conv2_0.conv2.0.weight", "decoder.convs.conv3_0.conv1.0.weight", "decoder.convs.conv3_0.conv2.0.weight", "decoder.convs.conv0_1.conv1.0.weight", "decoder.convs.conv0_1.conv2.0.weight", "decoder.convs.conv1_1.conv1.0.weight", "decoder.convs.conv1_1.conv2.0.weight", "decoder.convs.conv2_1.conv1.0.weight", "decoder.convs.conv2_1.conv2.0.weight", "decoder.convs.conv0_2.conv1.0.weight", "decoder.convs.conv0_2.conv2.0.weight", "decoder.convs.conv1_2.conv1.0.weight", "decoder.convs.conv1_2.conv2.0.weight", "decoder.convs.conv0_3.conv1.0.weight", "decoder.convs.conv0_3.conv2.0.weight", "head.final.0.weight", "head.final.0.bias". 
Unexpected key(s) in state_dict: "downsample_layers.0.conv.weight", "downsample_layers.0.conv.bias", "downsample_layers.0.post_norm.weight", "downsample_layers.1.pre_norm.weight", "downsample_layers.1.conv.weight", "downsample_layers.1.conv.bias", "downsample_layers.2.pre_norm.weight", "downsample_layers.2.conv.weight", "downsample_layers.2.conv.bias", "downsample_layers.3.pre_norm.weight", "downsample_layers.3.conv.weight", "downsample_layers.3.conv.bias", "stages.0.0.norm1.weight", "stages.0.0.token_mixer.pwconv1.weight", "stages.0.0.token_mixer.act1.scale", "stages.0.0.token_mixer.act1.bias", "stages.0.0.token_mixer.dwconv.weight", "stages.0.0.token_mixer.pwconv2.weight", "stages.0.0.norm2.weight", "stages.0.0.mlp.fc1.weight", "stages.0.0.mlp.act.scale", "stages.0.0.mlp.act.bias", "stages.0.0.mlp.fc2.weight", "stages.0.1.norm1.weight", "stages.0.1.token_mixer.pwconv1.weight", "stages.0.1.token_mixer.act1.scale", "stages.0.1.token_mixer.act1.bias", "stages.0.1.token_mixer.dwconv.weight", "stages.0.1.token_mixer.pwconv2.weight", "stages.0.1.norm2.weight", "stages.0.1.mlp.fc1.weight", "stages.0.1.mlp.act.scale", "stages.0.1.mlp.act.bias", "stages.0.1.mlp.fc2.weight", "stages.0.2.norm1.weight", "stages.0.2.token_mixer.pwconv1.weight", "stages.0.2.token_mixer.act1.scale", "stages.0.2.token_mixer.act1.bias", "stages.0.2.token_mixer.dwconv.weight", "stages.0.2.token_mixer.pwconv2.weight", "stages.0.2.norm2.weight", "stages.0.2.mlp.fc1.weight", "stages.0.2.mlp.act.scale", "stages.0.2.mlp.act.bias", "stages.0.2.mlp.fc2.weight", "stages.1.0.norm1.weight", "stages.1.0.token_mixer.pwconv1.weight", "stages.1.0.token_mixer.act1.scale", "stages.1.0.token_mixer.act1.bias", "stages.1.0.token_mixer.dwconv.weight", "stages.1.0.token_mixer.pwconv2.weight", "stages.1.0.norm2.weight", "stages.1.0.mlp.fc1.weight", "stages.1.0.mlp.act.scale", "stages.1.0.mlp.act.bias", "stages.1.0.mlp.fc2.weight", "stages.1.1.norm1.weight", "stages.1.1.token_mixer.pwconv1.weight", 
"stages.1.1.token_mixer.act1.scale", "stages.1.1.token_mixer.act1.bias", "stages.1.1.token_mixer.dwconv.weight", "stages.1.1.token_mixer.pwconv2.weight", "stages.1.1.norm2.weight", "stages.1.1.mlp.fc1.weight", "stages.1.1.mlp.act.scale", "stages.1.1.mlp.act.bias", "stages.1.1.mlp.fc2.weight", "stages.1.2.norm1.weight", "stages.1.2.token_mixer.pwconv1.weight", "stages.1.2.token_mixer.act1.scale", "stages.1.2.token_mixer.act1.bias", "stages.1.2.token_mixer.dwconv.weight", "stages.1.2.token_mixer.pwconv2.weight", "stages.1.2.norm2.weight", "stages.1.2.mlp.fc1.weight", "stages.1.2.mlp.act.scale", "stages.1.2.mlp.act.bias", "stages.1.2.mlp.fc2.weight", "stages.1.3.norm1.weight", "stages.1.3.token_mixer.pwconv1.weight", "stages.1.3.token_mixer.act1.scale", "stages.1.3.token_mixer.act1.bias", "stages.1.3.token_mixer.dwconv.weight", "stages.1.3.token_mixer.pwconv2.weight", "stages.1.3.norm2.weight", "stages.1.3.mlp.fc1.weight", "stages.1.3.mlp.act.scale", "stages.1.3.mlp.act.bias", "stages.1.3.mlp.fc2.weight", "stages.1.4.norm1.weight", "stages.1.4.token_mixer.pwconv1.weight", "stages.1.4.token_mixer.act1.scale", "stages.1.4.token_mixer.act1.bias", "stages.1.4.token_mixer.dwconv.weight", "stages.1.4.token_mixer.pwconv2.weight", "stages.1.4.norm2.weight", "stages.1.4.mlp.fc1.weight", "stages.1.4.mlp.act.scale", "stages.1.4.mlp.act.bias", "stages.1.4.mlp.fc2.weight", "stages.1.5.norm1.weight", "stages.1.5.token_mixer.pwconv1.weight", "stages.1.5.token_mixer.act1.scale", "stages.1.5.token_mixer.act1.bias", "stages.1.5.token_mixer.dwconv.weight", "stages.1.5.token_mixer.pwconv2.weight", "stages.1.5.norm2.weight", "stages.1.5.mlp.fc1.weight", "stages.1.5.mlp.act.scale", "stages.1.5.mlp.act.bias", "stages.1.5.mlp.fc2.weight", "stages.1.6.norm1.weight", "stages.1.6.token_mixer.pwconv1.weight", "stages.1.6.token_mixer.act1.scale", "stages.1.6.token_mixer.act1.bias", "stages.1.6.token_mixer.dwconv.weight", "stages.1.6.token_mixer.pwconv2.weight", "stages.1.6.norm2.weight", 
"stages.1.6.mlp.fc1.weight", "stages.1.6.mlp.act.scale", "stages.1.6.mlp.act.bias", "stages.1.6.mlp.fc2.weight", "stages.1.7.norm1.weight", "stages.1.7.token_mixer.pwconv1.weight", "stages.1.7.token_mixer.act1.scale", "stages.1.7.token_mixer.act1.bias", "stages.1.7.token_mixer.dwconv.weight", "stages.1.7.token_mixer.pwconv2.weight", "stages.1.7.norm2.weight", "stages.1.7.mlp.fc1.weight", "stages.1.7.mlp.act.scale", "stages.1.7.mlp.act.bias", "stages.1.7.mlp.fc2.weight", "stages.1.8.norm1.weight", "stages.1.8.token_mixer.pwconv1.weight", "stages.1.8.token_mixer.act1.scale", "stages.1.8.token_mixer.act1.bias", "stages.1.8.token_mixer.dwconv.weight", "stages.1.8.token_mixer.pwconv2.weight", "stages.1.8.norm2.weight", "stages.1.8.mlp.fc1.weight", "stages.1.8.mlp.act.scale", "stages.1.8.mlp.act.bias", "stages.1.8.mlp.fc2.weight", "stages.1.9.norm1.weight", "stages.1.9.token_mixer.pwconv1.weight", "stages.1.9.token_mixer.act1.scale", "stages.1.9.token_mixer.act1.bias", "stages.1.9.token_mixer.dwconv.weight", "stages.1.9.token_mixer.pwconv2.weight", "stages.1.9.norm2.weight", "stages.1.9.mlp.fc1.weight", "stages.1.9.mlp.act.scale", "stages.1.9.mlp.act.bias", "stages.1.9.mlp.fc2.weight", "stages.1.10.norm1.weight", "stages.1.10.token_mixer.pwconv1.weight", "stages.1.10.token_mixer.act1.scale", "stages.1.10.token_mixer.act1.bias", "stages.1.10.token_mixer.dwconv.weight", "stages.1.10.token_mixer.pwconv2.weight", "stages.1.10.norm2.weight", "stages.1.10.mlp.fc1.weight", "stages.1.10.mlp.act.scale", "stages.1.10.mlp.act.bias", "stages.1.10.mlp.fc2.weight", "stages.1.11.norm1.weight", "stages.1.11.token_mixer.pwconv1.weight", "stages.1.11.token_mixer.act1.scale", "stages.1.11.token_mixer.act1.bias", "stages.1.11.token_mixer.dwconv.weight", "stages.1.11.token_mixer.pwconv2.weight", "stages.1.11.norm2.weight", "stages.1.11.mlp.fc1.weight", "stages.1.11.mlp.act.scale", "stages.1.11.mlp.act.bias", "stages.1.11.mlp.fc2.weight", "stages.2.0.norm1.weight", 
"stages.2.0.token_mixer.qkv.weight", "stages.2.0.token_mixer.proj.weight", "stages.2.0.res_scale1.scale", "stages.2.0.norm2.weight", "stages.2.0.mlp.fc1.weight", "stages.2.0.mlp.act.scale", "stages.2.0.mlp.act.bias", "stages.2.0.mlp.fc2.weight", "stages.2.0.res_scale2.scale", "stages.2.1.norm1.weight", "stages.2.1.token_mixer.qkv.weight", "stages.2.1.token_mixer.proj.weight", "stages.2.1.res_scale1.scale", "stages.2.1.norm2.weight", "stages.2.1.mlp.fc1.weight", "stages.2.1.mlp.act.scale", "stages.2.1.mlp.act.bias", "stages.2.1.mlp.fc2.weight", "stages.2.1.res_scale2.scale", "stages.2.2.norm1.weight", "stages.2.2.token_mixer.qkv.weight", "stages.2.2.token_mixer.proj.weight", "stages.2.2.res_scale1.scale", "stages.2.2.norm2.weight", "stages.2.2.mlp.fc1.weight", "stages.2.2.mlp.act.scale", "stages.2.2.mlp.act.bias", "stages.2.2.mlp.fc2.weight", "stages.2.2.res_scale2.scale", "stages.2.3.norm1.weight", "stages.2.3.token_mixer.qkv.weight", "stages.2.3.token_mixer.proj.weight", "stages.2.3.res_scale1.scale", "stages.2.3.norm2.weight", "stages.2.3.mlp.fc1.weight", "stages.2.3.mlp.act.scale", "stages.2.3.mlp.act.bias", "stages.2.3.mlp.fc2.weight", "stages.2.3.res_scale2.scale", "stages.2.4.norm1.weight", "stages.2.4.token_mixer.qkv.weight", "stages.2.4.token_mixer.proj.weight", "stages.2.4.res_scale1.scale", "stages.2.4.norm2.weight", "stages.2.4.mlp.fc1.weight", "stages.2.4.mlp.act.scale", "stages.2.4.mlp.act.bias", "stages.2.4.mlp.fc2.weight", "stages.2.4.res_scale2.scale", "stages.2.5.norm1.weight", "stages.2.5.token_mixer.qkv.weight", "stages.2.5.token_mixer.proj.weight", "stages.2.5.res_scale1.scale", "stages.2.5.norm2.weight", "stages.2.5.mlp.fc1.weight", "stages.2.5.mlp.act.scale", "stages.2.5.mlp.act.bias", "stages.2.5.mlp.fc2.weight", "stages.2.5.res_scale2.scale", "stages.2.6.norm1.weight", "stages.2.6.token_mixer.qkv.weight", "stages.2.6.token_mixer.proj.weight", "stages.2.6.res_scale1.scale", "stages.2.6.norm2.weight", "stages.2.6.mlp.fc1.weight", 
"stages.2.6.mlp.act.scale", "stages.2.6.mlp.act.bias", "stages.2.6.mlp.fc2.weight", "stages.2.6.res_scale2.scale", "stages.2.7.norm1.weight", "stages.2.7.token_mixer.qkv.weight", "stages.2.7.token_mixer.proj.weight", "stages.2.7.res_scale1.scale", "stages.2.7.norm2.weight", "stages.2.7.mlp.fc1.weight", "stages.2.7.mlp.act.scale", "stages.2.7.mlp.act.bias", "stages.2.7.mlp.fc2.weight", "stages.2.7.res_scale2.scale", "stages.2.8.norm1.weight", "stages.2.8.token_mixer.qkv.weight", "stages.2.8.token_mixer.proj.weight", "stages.2.8.res_scale1.scale", "stages.2.8.norm2.weight", "stages.2.8.mlp.fc1.weight", "stages.2.8.mlp.act.scale", "stages.2.8.mlp.act.bias", "stages.2.8.mlp.fc2.weight", "stages.2.8.res_scale2.scale", "stages.2.9.norm1.weight", "stages.2.9.token_mixer.qkv.weight", "stages.2.9.token_mixer.proj.weight", "stages.2.9.res_scale1.scale", "stages.2.9.norm2.weight", "stages.2.9.mlp.fc1.weight", "stages.2.9.mlp.act.scale", "stages.2.9.mlp.act.bias", "stages.2.9.mlp.fc2.weight", "stages.2.9.res_scale2.scale", "stages.2.10.norm1.weight", "stages.2.10.token_mixer.qkv.weight", "stages.2.10.token_mixer.proj.weight", "stages.2.10.res_scale1.scale", "stages.2.10.norm2.weight", "stages.2.10.mlp.fc1.weight", "stages.2.10.mlp.act.scale", "stages.2.10.mlp.act.bias", "stages.2.10.mlp.fc2.weight", "stages.2.10.res_scale2.scale", "stages.2.11.norm1.weight", "stages.2.11.token_mixer.qkv.weight", "stages.2.11.token_mixer.proj.weight", "stages.2.11.res_scale1.scale", "stages.2.11.norm2.weight", "stages.2.11.mlp.fc1.weight", "stages.2.11.mlp.act.scale", "stages.2.11.mlp.act.bias", "stages.2.11.mlp.fc2.weight", "stages.2.11.res_scale2.scale", "stages.2.12.norm1.weight", "stages.2.12.token_mixer.qkv.weight", "stages.2.12.token_mixer.proj.weight", "stages.2.12.res_scale1.scale", "stages.2.12.norm2.weight", "stages.2.12.mlp.fc1.weight", "stages.2.12.mlp.act.scale", "stages.2.12.mlp.act.bias", "stages.2.12.mlp.fc2.weight", "stages.2.12.res_scale2.scale", "stages.2.13.norm1.weight", 
"stages.2.13.token_mixer.qkv.weight", "stages.2.13.token_mixer.proj.weight", "stages.2.13.res_scale1.scale", "stages.2.13.norm2.weight", "stages.2.13.mlp.fc1.weight", "stages.2.13.mlp.act.scale", "stages.2.13.mlp.act.bias", "stages.2.13.mlp.fc2.weight", "stages.2.13.res_scale2.scale", "stages.2.14.norm1.weight", "stages.2.14.token_mixer.qkv.weight", "stages.2.14.token_mixer.proj.weight", "stages.2.14.res_scale1.scale", "stages.2.14.norm2.weight", "stages.2.14.mlp.fc1.weight", "stages.2.14.mlp.act.scale", "stages.2.14.mlp.act.bias", "stages.2.14.mlp.fc2.weight", "stages.2.14.res_scale2.scale", "stages.2.15.norm1.weight", "stages.2.15.token_mixer.qkv.weight", "stages.2.15.token_mixer.proj.weight", "stages.2.15.res_scale1.scale", "stages.2.15.norm2.weight", "stages.2.15.mlp.fc1.weight", "stages.2.15.mlp.act.scale", "stages.2.15.mlp.act.bias", "stages.2.15.mlp.fc2.weight", "stages.2.15.res_scale2.scale", "stages.2.16.norm1.weight", "stages.2.16.token_mixer.qkv.weight", "stages.2.16.token_mixer.proj.weight", "stages.2.16.res_scale1.scale", "stages.2.16.norm2.weight", "stages.2.16.mlp.fc1.weight", "stages.2.16.mlp.act.scale", "stages.2.16.mlp.act.bias", "stages.2.16.mlp.fc2.weight", "stages.2.16.res_scale2.scale", "stages.2.17.norm1.weight", "stages.2.17.token_mixer.qkv.weight", "stages.2.17.token_mixer.proj.weight", "stages.2.17.res_scale1.scale", "stages.2.17.norm2.weight", "stages.2.17.mlp.fc1.weight", "stages.2.17.mlp.act.scale", "stages.2.17.mlp.act.bias", "stages.2.17.mlp.fc2.weight", "stages.2.17.res_scale2.scale", "stages.3.0.norm1.weight", "stages.3.0.token_mixer.qkv.weight", "stages.3.0.token_mixer.proj.weight", "stages.3.0.res_scale1.scale", "stages.3.0.norm2.weight", "stages.3.0.mlp.fc1.weight", "stages.3.0.mlp.act.scale", "stages.3.0.mlp.act.bias", "stages.3.0.mlp.fc2.weight", "stages.3.0.res_scale2.scale", "stages.3.1.norm1.weight", "stages.3.1.token_mixer.qkv.weight", "stages.3.1.token_mixer.proj.weight", "stages.3.1.res_scale1.scale", 
"stages.3.1.norm2.weight", "stages.3.1.mlp.fc1.weight", "stages.3.1.mlp.act.scale", "stages.3.1.mlp.act.bias", "stages.3.1.mlp.fc2.weight", "stages.3.1.res_scale2.scale", "stages.3.2.norm1.weight", "stages.3.2.token_mixer.qkv.weight", "stages.3.2.token_mixer.proj.weight", "stages.3.2.res_scale1.scale", "stages.3.2.norm2.weight", "stages.3.2.mlp.fc1.weight", "stages.3.2.mlp.act.scale", "stages.3.2.mlp.act.bias", "stages.3.2.mlp.fc2.weight", "stages.3.2.res_scale2.scale", "norm.weight", "norm.bias", "head.fc1.weight", "head.fc1.bias", "head.norm.weight", "head.norm.bias", "head.fc2.weight", "head.fc2.bias".
This error is related to the command model.load_state_dict(ckpt).
Before running this script, I downloaded the checkpoint you mentioned in the README: https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth
I have seen that the main.py script loads the model from a checkpoint like this:
if args.resume is not None:
    ckpt = torch.load(args.resume, map_location='cpu')['state_dict']
    ckpt["encoder.conv2.0.weight"] = ckpt.pop("encoder.conv2.1.weight")
    ckpt["encoder.conv2.0.bias"] = ckpt.pop("encoder.conv2.1.bias")
    model.load_state_dict(ckpt)
    print("load pretrained model, successfully!")
However, I cannot load the model this way, because I get an error saying that the ckpt object has no key 'state_dict'.
Also, I cannot pop "encoder.conv2.1.weight" and "encoder.conv2.1.bias", since there are no such keys in the ckpt object.
Here are the versions of the Python libraries installed in my virtual environment:
Ubuntu 22.04 Python 3.10.12 Package Version ------------------------ ---------- certifi 2024.8.30 charset-normalizer 3.4.0 filelock 3.16.1 fsspec 2024.10.0 huggingface-hub 0.26.0 idna 3.10 Jinja2 3.1.4 MarkupSafe 3.0.2 mpmath 1.3.0 networkx 3.4.1 numpy 2.1.2 nvidia-cublas-cu12 12.4.5.8 nvidia-cuda-cupti-cu12 12.4.127 nvidia-cuda-nvrtc-cu12 12.4.127 nvidia-cuda-runtime-cu12 12.4.127 nvidia-cudnn-cu12 9.1.0.70 nvidia-cufft-cu12 11.2.1.3 nvidia-curand-cu12 10.3.5.147 nvidia-cusolver-cu12 11.6.1.9 nvidia-cusparse-cu12 12.3.1.170 nvidia-nccl-cu12 2.21.5 nvidia-nvjitlink-cu12 12.4.127 nvidia-nvtx-cu12 12.4.127 packaging 24.1 pillow 11.0.0 pip 22.0.2 PyYAML 6.0.2 requests 2.32.3 safetensors 0.4.5 setuptools 59.6.0 sympy 1.13.1 timm 1.0.11 torch 2.5.0 torchvision 0.20.0 tqdm 4.66.5 triton 3.1.0 typing_extensions 4.12.2 urllib3 2.2.3
Could you please provide us with a separate inference script that we can easily run on an arbitrary input image using the pretrained weights you mentioned?
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you commented.Message ID: @.***>
Could you add an easy-to-run inference script to the repository? I'm having trouble running the edge extraction on an arbitrary input image using the model checkpoint you provided in the README.