Open kail85 opened 7 months ago
A trained Swin MAE model is used the same way as any other pre-trained model: load its weights, then run downstream tasks directly with the official Swin-Unet code.
Here I show the function that loads the pre-trained model, and you'll see that it's basically the same as what's done in Swin-Unet.
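import os

import torch


def load_weight(model, device, path):
    # An empty path means no pre-trained weights are loaded.
    if path == '':
        print('No pre-training weights are used.')
        return model
    assert os.path.exists(path)
    # The checkpoint stores the weights under the 'model' key.
    full_dict = torch.load(path, map_location=device)['model']
    model_dict = model.state_dict()
    # Drop the decoder weights saved during pre-training.
    for key in list(full_dict):
        if key.startswith('layers_up'):
            del full_dict[key]
    # Copy encoder stage i ('layers.i...') to the mirrored decoder
    # stage 2 - i ('layers_up.(2 - i)...'), skipping negative indices.
    for key in list(full_dict):
        if key.startswith('layers'):
            current_layer_index = 2 - int(key[7:8])
            if current_layer_index >= 0:
                current_key = "layers_up." + str(current_layer_index) + key[8:]
                full_dict[current_key] = full_dict[key]
    # Remove entries whose shape does not match the target model.
    for k in list(full_dict.keys()):
        if k in model_dict:
            if full_dict[k].shape != model_dict[k].shape:
                print(f"Delete: '{k}'; "
                      f"Weight shape: {full_dict[k].shape}; Model shape: {model_dict[k].shape}")
                del full_dict[k]
    # strict=False tolerates missing and unexpected keys; the result lists them.
    result = model.load_state_dict(full_dict, strict=False)
    print(result)
    return model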
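To make the key remapping in the second loop easier to follow, here is a minimal sketch that applies the same rule to a toy dictionary, without torch. The helper name `mirror_encoder_keys`, its `num_up_layers` parameter, and the sample keys are made up for illustration, and plain strings stand in for tensors.

```python
def mirror_encoder_keys(state, num_up_layers=3):
    """Copy 'layers.<i>...' entries to 'layers_up.<num_up_layers - 1 - i>...'.

    Hypothetical helper for illustration only. Like the slicing in
    load_weight, it assumes the stage index is a single digit.
    """
    out = dict(state)
    for key, value in state.items():
        if key.startswith('layers.'):
            up_index = (num_up_layers - 1) - int(key[7:8])
            if up_index >= 0:  # stages without a mirrored decoder stage are skipped
                out['layers_up.' + str(up_index) + key[8:]] = value
    return out


# Toy "checkpoint": the keys are invented examples, the values stand in for tensors.
toy = {
    'patch_embed.proj.weight': 'w_embed',
    'layers.0.blocks.0.attn.qkv.weight': 'w0',
    'layers.2.blocks.1.mlp.fc1.bias': 'b2',
    'layers.3.blocks.0.norm1.weight': 'n3',
}
remapped = mirror_encoder_keys(toy)
for name in sorted(remapped):
    print(name)
```

In this toy input, stage 0 is copied to `layers_up.2` and stage 2 to `layers_up.0`, while stage 3 gets no copy because its mirrored index would be negative. That is the case the `>= 0` check in `load_weight` guards against.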
If feasible, could you provide the inference code for classification or segmentation using a trained model, please?