Zian-Xu / Swin-MAE

PyTorch implementation of Swin MAE https://arxiv.org/abs/2212.13805

Inference for classification or segmentation #8

Open kail85 opened 7 months ago

kail85 commented 7 months ago

If feasible, could you provide the inference code for classification or segmentation using a trained model, please?

Zian-Xu commented 7 months ago

A trained Swin MAE model is used the same way as any other pre-trained model. Downstream tasks can be run directly with the official Swin-Unet code.

Here is the function that loads the pre-trained weights; you'll see that it is basically the same as what Swin-Unet does.

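import os

import torch


def load_weight(model, device, path):
    # An empty path means no pre-trained weights are loaded.
    if path == '':
        print('No pre-training weights are used.')
        return model
    assert os.path.exists(path), f'Checkpoint not found: {path}'
    # The checkpoint keeps the weights under the 'model' key.
    full_dict = torch.load(path, map_location=device)['model']
    model_dict = model.state_dict()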

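    # Discard checkpoint entries whose keys start with 'layers_up';
    # they are rebuilt from the 'layers' weights in the next loop.
    for key in list(full_dict):
        if key.startswith('layers_up'):
            del full_dict[key]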

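    # Copy stage i of 'layers' to stage 2 - i of 'layers_up'.
    # key[7:8] is the stage index in 'layers.<i>...'; stages above 2 are skipped.
    for key in list(full_dict):
        if key.startswith('layers.'):
            current_layer_index = 2 - int(key[7:8])
            if current_layer_index >= 0:
                current_key = "layers_up." + str(current_layer_index) + key[8:]
                full_dict[current_key] = full_dict[key]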

    # Drop weights whose shape differs from the target model's parameter.
    for k in list(full_dict.keys()):
        if k in model_dict:
            if full_dict[k].shape != model_dict[k].shape:
                print(f"Delete: '{k}'; "
                      f"Weight shape: {full_dict[k].shape}; Model shape: {model_dict[k].shape}")
                del full_dict[k]

    # strict=False tolerates missing and unexpected keys; the result lists both.
    result = model.load_state_dict(full_dict, strict=False)
    print(result)
    return model
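
To make the key renaming in the second loop easier to follow, here is a minimal sketch that applies the same rule to plain strings, so it runs without PyTorch or a checkpoint. The helper name `remap_encoder_keys`, its `num_up_layers` parameter, and the example key names are assumptions made for illustration; they are not part of the repository, and the keys in a real checkpoint may differ.

```python
def remap_encoder_keys(keys, num_up_layers=3):
    """Return {new 'layers_up' key: original 'layers' key}.

    Hypothetical helper: mirrors stage i of 'layers' onto stage
    (num_up_layers - 1 - i) of 'layers_up', as load_weight does with 2 - i.
    """
    mapping = {}
    for key in keys:
        if not key.startswith("layers."):
            continue  # e.g. patch embedding or norm weights are left alone
        _, stage, suffix = key.split(".", 2)
        mirrored = (num_up_layers - 1) - int(stage)
        if mirrored >= 0:  # stages beyond the mirrored range have no counterpart
            mapping[f"layers_up.{mirrored}.{suffix}"] = key
    return mapping


# Illustrative key names only (assumed, not read from a real checkpoint).
example_keys = [
    "patch_embed.proj.weight",
    "layers.0.blocks.0.attn.qkv.weight",
    "layers.2.blocks.1.mlp.fc1.bias",
    "layers.3.blocks.0.norm1.weight",
]

for new_key, old_key in remap_encoder_keys(example_keys).items():
    print(f"{old_key} -> {new_key}")
```

With these example keys, stage 0 is copied to `layers_up.2`, stage 2 to `layers_up.0`, and stage 3 is skipped because its mirrored index would be negative.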