allenai / satlas

Apache License 2.0

Suggestion to improve docs #20

robmarkcole closed this issue 9 months ago

robmarkcole commented 10 months ago

Currently the Python code in the docs is not formatted and is difficult to read. Adding a couple of lines would improve this, e.g.

```python
import torch
import torchvision

model = torchvision.models.swin_transformer.swin_v2_b()

# Make sure to load a multi-image model here.
# Only the multi-image models are trained to provide robust features
# after max temporal pooling.
full_state_dict = torch.load('satlas-model-v1-lowres-multi.pth')

# Extract just the Swin backbone parameters from the full state dict.
swin_prefix = 'backbone.backbone.backbone.'
swin_state_dict = {
    k[len(swin_prefix):]: v
    for k, v in full_state_dict.items()
    if k.startswith(swin_prefix)
}
model.load_state_dict(swin_state_dict)

# Assume im is shape (N, C, H, W), with N aligned images of the same
# location at different times.
# First get feature maps of each individual image.
x = im
outputs = []
for layer in model.features:
    x = layer(x)
    outputs.append(x.permute(0, 3, 1, 2))
feature_maps = [outputs[-7], outputs[-5], outputs[-3], outputs[-1]]

# Now apply max temporal pooling.
feature_maps = [
    m.amax(dim=0)
    for m in feature_maps
]
# feature_maps can be passed to a head, and the head or entire model can
# be trained to fine-tune on task-specific labels.
```
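
The prefix-stripping step in the snippet above can be tried without downloading any weights. The following is a minimal sketch that uses a plain dictionary in place of a real state dict. The toy keys and values are made up for illustration, and `strip_prefix` is a hypothetical helper, not part of satlas or torchvision.

```python
def strip_prefix(state_dict, prefix):
    """Keep only entries whose key starts with prefix, with the prefix removed."""
    return {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }


# Toy stand-in for a full checkpoint; in a real checkpoint the values are tensors.
# These key names are illustrative assumptions, not taken from the actual file.
full_state_dict = {
    'backbone.backbone.backbone.features.0.0.weight': 'w0',
    'backbone.backbone.backbone.norm.bias': 'b0',
    'heads.0.layer.weight': 'w1',
}

swin_state_dict = strip_prefix(full_state_dict, 'backbone.backbone.backbone.')
print(sorted(swin_state_dict))
```

Only the keys under the backbone prefix survive, renamed so they line up with the parameter names that the bare Swin model expects in `load_state_dict`.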

This can be done using ```python fenced code blocks.

favyen2 commented 9 months ago

Thanks, I have replaced the space-indented Python code blocks with fenced ones.
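
For anyone making the same change in their own docs, the conversion to fenced blocks can be scripted. Below is a rough sketch, assuming code blocks are indented by exactly four spaces and set apart from prose by blank lines. `fence_indented_blocks` is a hypothetical helper and not what was used for the satlas docs; it does not handle tabs, code nested inside lists, or indented continuation lines of a paragraph.

```python
def fence_indented_blocks(text, lang='python'):
    """Convert four-space-indented Markdown code blocks to fenced blocks."""
    fence = '`' * 3  # built at runtime to avoid a literal fence in this example
    lines = text.split('\n')
    out = []
    i = 0
    while i < len(lines):
        line = lines[i]
        if line.startswith('    ') and line.strip():
            # Scan forward over indented or blank lines, remembering the
            # last non-blank indented line so trailing blanks stay outside.
            last = i
            j = i
            while j < len(lines) and (lines[j].startswith('    ') or not lines[j].strip()):
                if lines[j].strip():
                    last = j
                j += 1
            out.append(fence + lang)
            out.extend(l[4:] for l in lines[i:last + 1])
            out.append(fence)
            i = last + 1
        else:
            out.append(line)
            i += 1
    return '\n'.join(out)


doc = 'Intro:\n\n    import torch\n\n    x = 1\n\nDone.'
converted = fence_indented_blocks(doc)
print(converted)
```

Blank lines inside a block are kept, so a snippet with paragraph-style spacing ends up in a single fence rather than being split into several.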