Currently the Python code in the docs is not formatted and is difficult to read. Adding a couple of lines would improve this, e.g.
```python
import torch
import torchvision

model = torchvision.models.swin_transformer.swin_v2_b()

# Make sure to load a multi-image model here.
# Only the multi-image models are trained to provide robust features after max temporal pooling.
full_state_dict = torch.load('satlas-model-v1-lowres-multi.pth')

# Extract just the Swin backbone parameters from the full state dict.
swin_prefix = 'backbone.backbone.backbone.'
swin_state_dict = {k[len(swin_prefix):]: v for k, v in full_state_dict.items() if k.startswith(swin_prefix)}
model.load_state_dict(swin_state_dict)

# Assume im is shape (N, C, H, W), with N aligned images of the same location at different times.
# First get feature maps of each individual image.
x = im
outputs = []
for layer in model.features:
    x = layer(x)
    outputs.append(x.permute(0, 3, 1, 2))
feature_maps = [outputs[-7], outputs[-5], outputs[-3], outputs[-1]]

# Now apply max temporal pooling.
feature_maps = [
    m.amax(dim=0)
    for m in feature_maps
]

# feature_maps can be passed to a head, and the head or entire model can be trained to fine-tune on task-specific labels.
```
This only needs a code fence tagged `python` opened before the snippet and closed after it.