NASA-IMPACT / hls-foundation-os

This repository contains examples of fine-tuning the Harmonized Landsat and Sentinel-2 (HLS) Prithvi foundation model.
Apache License 2.0

Request info on Unet baseline #50

Closed: robmarkcole closed this issue 9 months ago

robmarkcole commented 10 months ago

I wish to reproduce the Unet results from

[image: reported Unet baseline results]

Further information is requested, since mmsegmentation offers many Unet variants; the complete config that was used would be ideal.

robmarkcole commented 10 months ago

I've had a stab at it with the following:

import os

dist_params = dict(backend="nccl")
log_level = "INFO"
load_from = None
resume_from = None
cudnn_benchmark = True
custom_imports = dict(imports=["geospatial_fm"])
num_frames = 3
img_size = 224
num_workers = 2
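# 3 temporal frames x 6 bands = 18 input channels, matching in_channels=len(bands)*num_frames in the model below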

# model
# TO BE DEFINED BY USER: model path
# pretrained_weights_path = "/teamspace/studios/this_studio/hls-foundation-os/prithvi/Prithvi_100M.pt"
num_layers = 6
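# NOTE: num_layers appears unused in this config (likely left over from the Prithvi ViT setup)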
max_epochs = 80
eval_epoch_interval = 5

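# Per-class loss weights; the 13 entries presumably correspond, in order, to the 13 classes in CLASSES below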
loss_weights_multi = [
    0.386375,
    0.661126,
    0.548184,
    0.640482,
    0.876862,
    0.925186,
    3.249462,
    1.542289,
    2.175141,
    2.272419,
    3.062762,
    3.626097,
    1.198702,
]
loss_func = dict(
    type="CrossEntropyLoss",
    use_sigmoid=False,
    class_weight=loss_weights_multi,
    avg_non_ignore=True,
)

# TO BE DEFINED BY USER: Save directory
experiment = "classification"
project_dir = "/teamspace/studios/this_studio/project_classification_unet"
work_dir = os.path.join(project_dir, experiment)
save_path = work_dir

dataset_type = "GeospatialDataset"

# TO BE DEFINED BY USER: data directory
data_root = "/teamspace/studios/this_studio/data/multi-temporal-crop-classification/"

splits = dict(
    train="training_data.txt",
    val="validation_data.txt",
    test="validation_data.txt",
)

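# Normalization statistics: the 6 per-band means/stds repeated once per temporal frame (3 x 6 = 18 values)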
img_norm_cfg = dict(
    means=[
        494.905781,
        815.239594,
        924.335066,
        2968.881459,
        2634.621962,
        1739.579917,
        494.905781,
        815.239594,
        924.335066,
        2968.881459,
        2634.621962,
        1739.579917,
        494.905781,
        815.239594,
        924.335066,
        2968.881459,
        2634.621962,
        1739.579917,
    ],
    stds=[
        284.925432,
        357.84876,
        575.566823,
        896.601013,
        951.900334,
        921.407808,
        284.925432,
        357.84876,
        575.566823,
        896.601013,
        951.900334,
        921.407808,
        284.925432,
        357.84876,
        575.566823,
        896.601013,
        951.900334,
        921.407808,
    ],
)

bands = [0, 1, 2, 3, 4, 5]

tile_size = 224
orig_nsize = 512
crop_size = (tile_size, tile_size)
train_pipeline = [
    dict(type="LoadGeospatialImageFromFile", to_float32=True),
    dict(type="LoadGeospatialAnnotations", reduce_zero_label=True),
    dict(type="RandomFlip", prob=0.5),
    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(type="TorchRandomCrop", crop_size=crop_size),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands)*num_frames, tile_size, tile_size),
    ),
    dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)),
    dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"),
    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
]

test_pipeline = [
    dict(type="LoadGeospatialImageFromFile", to_float32=True),
    dict(type="ToTensor", keys=["img"]),
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands)*num_frames, tile_size, tile_size),
    ),
    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
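    # NOTE: meta_keys below omit 'flip'/'flip_direction', which mmseg's stock
    # EncoderDecoder reads at inference time (see the KeyError further down)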
    dict(
        type="CollectTestList",
        keys=["img"],
        meta_keys=[
            "img_info",
            "seg_fields",
            "img_prefix",
            "seg_prefix",
            "filename",
            "ori_filename",
            "img",
            "img_shape",
            "ori_shape",
            "pad_shape",
            "scale_factor",
            "img_norm_cfg",
        ],
    ),
]

CLASSES = (
    "Natural Vegetation",
    "Forest",
    "Corn",
    "Soybeans",
    "Wetlands",
    "Developed/Barren",
    "Open Water",
    "Winter Wheat",
    "Alfalfa",
    "Fallow/Idle Cropland",
    "Cotton",
    "Sorghum",
    "Other",
)

dataset = "GeospatialDataset"
data = dict(
    samples_per_gpu=4,  # originally 8
    workers_per_gpu=4,
    train=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir="training_chips",
        ann_dir="training_chips",
        pipeline=train_pipeline,
        img_suffix="_merged.tif",
        seg_map_suffix=".mask.tif",
        split=splits["train"],
    ),
    val=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir="validation_chips",
        ann_dir="validation_chips",
        pipeline=test_pipeline,
        img_suffix="_merged.tif",
        seg_map_suffix=".mask.tif",
        split=splits["val"],
    ),
    test=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir="validation_chips",
        ann_dir="validation_chips",
        pipeline=test_pipeline,
        img_suffix="_merged.tif",
        seg_map_suffix=".mask.tif",
        split=splits["val"],
    ),
)

optimizer = dict(type="Adam", lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy="poly",
    warmup="linear",
    warmup_iters=1500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False,
)
log_config = dict(
    interval=10, hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")]
)

checkpoint_config = dict(by_epoch=True, interval=100, out_dir=save_path)

evaluation = dict(
    interval=eval_epoch_interval,
    metric="mIoU",
    pre_eval=True,
    save_best="mIoU",
    by_epoch=True,
)
reduce_train_set = dict(reduce_train_set=False)
reduce_factor = dict(reduce_factor=1)
runner = dict(type="EpochBasedRunner", max_epochs=max_epochs)
workflow = [("train", 1)]
norm_cfg = dict(type="BN", requires_grad=True)

# from https://github.com/open-mmlab/mmsegmentation/issues/289
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='UNet',
        in_channels=len(bands)*num_frames,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        with_cp=False,
        conv_cfg=None,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU'),
        upsample_cfg=dict(type='InterpConv'),
        norm_eval=False),
    decode_head=dict(
        type='FCNHead',
        in_channels=64,
        in_index=4,
        channels=64,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=len(CLASSES),
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=loss_func,
    ),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=128,
        in_index=3,
        channels=64,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=len(CLASSES),
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=loss_func,
    ),
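    # test_cfg below runs sliding-window inference: 224x224 crops with stride 112 (50% overlap)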
    train_cfg=dict(),
    test_cfg=dict(
        mode="slide",
        stride=(int(tile_size / 2), int(tile_size / 2)),
        crop_size=(tile_size, tile_size),
    ),
)

auto_resume = False
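
For reference, I saved the above as a config file and launched training with mim, as the repo README does for the other configs; the path below is just where I happened to put it:

mim train mmsegmentation configs/unet_crop_classification.py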

This trains for 5 epochs then fails on validation with:

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 253, in inference
    flip = img_meta[0]['flip']
KeyError: 'flip'

No idea how to proceed; this appears to be https://github.com/open-mmlab/mmsegmentation/issues/231, although I did not change the augmentations.
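
For context: the failing line is mmseg's EncoderDecoder.inference reading img_meta[0]['flip']. The standard mmseg test pipeline sets that key via MultiScaleFlipAug, but CollectTestList above never collects it. A minimal, untested sketch of a workaround (InjectFlipMeta is a made-up name, not an mmseg API) would be to register a transform that injects default flip metadata and to collect those two keys:

from mmseg.datasets.builder import PIPELINES

@PIPELINES.register_module()
class InjectFlipMeta:
    """Inject the default 'flip' metadata that EncoderDecoder.inference expects."""

    def __call__(self, results):
        # Report no test-time flipping so inference skips the un-flip step
        results.setdefault("flip", False)
        results.setdefault("flip_direction", "horizontal")
        return results

Then add dict(type="InjectFlipMeta") to test_pipeline before CollectTestList and append "flip" and "flip_direction" to its meta_keys.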

HamedAlemo commented 10 months ago

Hi @robmarkcole, the UNet baseline is available here. @samKhallaghi can answer any questions you might have.

HamedAlemo commented 9 months ago

@robmarkcole I'm closing this in favor of the tickets you have opened on the Unet repo.