allenai / satlas

Apache License 2.0

Config and model mismatch in CustomInference with aerial images. #44

Closed. utkarshmall13 closed this issue 1 month ago

utkarshmall13 commented 3 months ago

Hi, I was trying to run a segmentation model on a custom image using these instructions (https://github.com/allenai/satlas/blob/main/SatlasPretrain.md#visualizing-outputs-on-new-images) and CustomInference.

Like this:

wget -O models/highres/aerial_swinb_si.pth "https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_si.pth?download=true"

import json
import skimage.io
import torch
import torchvision

import satlas.model.model
import satlas.model.evaluate

config_path = 'configs/aerial/swinb_si.txt'
weights_path = 'models/highres/aerial_swinb_si.pth'
image_path = 'image.jpg'
...
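For anyone scripting this, the wget command in the question can be sketched in Python with only the standard library. This assumes the Hugging Face URL pattern https://huggingface.co/<repo>/resolve/<revision>/<filename> that the wget URL follows; weights_url and download_weights are hypothetical helper names, not part of satlas.

```python
import os
import urllib.request

# Repo name taken from the wget URL in the question.
HF_REPO = "allenai/satlas-pretrain"


def weights_url(filename: str, repo: str = HF_REPO, revision: str = "main") -> str:
    """Build a Hugging Face 'resolve' URL for a file in a model repo (hypothetical helper)."""
    return f"https://huggingface.co/{repo}/resolve/{revision}/{filename}?download=true"


def download_weights(filename: str, dest_dir: str = "models/highres") -> str:
    """Download filename into dest_dir, like wget -O, and return the local path."""
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, filename)
    urllib.request.urlretrieve(weights_url(filename), dest)
    return dest
```

For example, download_weights('aerial_swinb_si.pth') writes models/highres/aerial_swinb_si.pth, the same target as the wget -O above. The huggingface_hub package's hf_hub_download(repo_id=..., filename=...) is an alternative that also caches downloads.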

However, while I'm able to load the model sentinel2_swinb_si_rgb.pth with the suggested config file configs/sentinel2/swinb_si_rgb.txt, the same does not work for the aerial image model aerial_swinb_si.pth with the config file configs/aerial/swinb_si.txt. The tasks and the number of heads seem to differ between the config file and the model checkpoint. I've also tried other config files, such as swinb_si_pretrain.txt, but none of them match the checkpoint.
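One way to confirm this kind of mismatch before calling load_state_dict is to count the task heads in the checkpoint and compare that with the number of entries in config['Tasks']. This is a hypothetical sketch: it assumes head parameters are named with a top-level prefix such as heads.<index>., which is a guess about the checkpoint layout, so print a few keys from torch.load(...) first and adjust prefix to match.

```python
import re
from typing import Iterable, Set


def head_indices(keys: Iterable[str], prefix: str = "heads") -> Set[int]:
    """Return the distinct integer indices that follow 'prefix.' in parameter names.

    The default prefix "heads" is an assumption about how per-task heads are
    named; only keys that start with the prefix are considered.
    """
    pattern = re.compile(rf"^{re.escape(prefix)}\.(\d+)\.")
    found = set()
    for key in keys:
        match = pattern.match(key)
        if match:
            found.add(int(match.group(1)))
    return found


def check_task_count(config: dict, keys: Iterable[str], prefix: str = "heads") -> bool:
    """True if the config lists as many tasks as the checkpoint appears to have heads."""
    return len(config["Tasks"]) == len(head_indices(keys, prefix))
```

With a real checkpoint, pass the keys of torch.load(weights_path, map_location='cpu') as keys and the parsed JSON config as config.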

Would it be possible to get the correct checkpoint or the corresponding config file?

favyen2 commented 3 months ago

It looks like aerial_swinb_si.pth and aerial_swinb_mi.pth are swapped on Hugging Face. We will fix those weights. In the meantime this worked for me:

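import json
import skimage.io
import torch
import torchvision

import satlas.model.dataset  # provides satlas.model.dataset.tasks, used below
import satlas.model.model
import satlas.model.evaluate

# Locations of model config and weights, and the 8-bit RGB image to run inference on.
# aerial_swinb_si.pth and aerial_swinb_mi.pth were swapped on Hugging Face when this
# was written, so the single-image config is paired with the file named aerial_swinb_mi.pth.
config_path = 'configs/aerial/swinb_si.txt'
weights_path = 'aerial_swinb_mi.pth'
image_path = 'image.jpg'

# Read config and initialize the model.
with open(config_path, 'r') as f:
    config = json.load(f)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Look up each task's definition by name when the config does not include it inline.
for spec in config['Tasks']:
    if 'Task' not in spec:
        spec['Task'] = satlas.model.dataset.tasks[spec['Name']]
model = satlas.model.model.Model({
    'config': config['Model'],
    'channels': config['Channels'],
    'tasks': config['Tasks'],
})
# Load the checkpoint onto the selected device, then switch to inference mode.
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()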
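If load_state_dict still fails, listing the exact differences helps decide whether the config or the checkpoint is the wrong one. PyTorch's load_state_dict(strict=False) reports missing and unexpected keys but still raises on shape mismatches, so this minimal sketch compares names and shapes in plain Python. diff_state_dicts is a hypothetical helper, and the key names in any example are made up.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(
    checkpoint: Dict[str, Shape], model: Dict[str, Shape]
) -> Dict[str, List[str]]:
    """Compare parameter names and shapes between a checkpoint and a model.

    Both arguments map parameter names to tensor shapes.
    """
    ckpt_keys = set(checkpoint)
    model_keys = set(model)
    shared = ckpt_keys & model_keys
    return {
        # Expected by the model but absent from the checkpoint.
        "missing": sorted(model_keys - ckpt_keys),
        # Present in the checkpoint but unknown to the model.
        "unexpected": sorted(ckpt_keys - model_keys),
        # Same name on both sides, but the tensor shapes disagree.
        "shape_mismatch": sorted(k for k in shared if checkpoint[k] != model[k]),
    }
```

To use it with the code above, build the two dicts with {k: tuple(v.shape) for k, v in state_dict.items()} and the same expression over model.state_dict(). Many unexpected or missing head keys point to a task-list mismatch in the config rather than a corrupted file.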

favyen2 commented 1 month ago

The filenames have now been fixed on Hugging Face.

utkarshmall13 commented 1 month ago

Thanks a lot! This works!