Closed robmarkcole closed 9 months ago
I've had a stab at it with the following:
import os
# ---------------------------------------------------------------------------
# General runtime settings
# ---------------------------------------------------------------------------
dist_params = {"backend": "nccl"}
log_level = "INFO"
# No checkpoint is loaded or resumed.
load_from = resume_from = None
cudnn_benchmark = True
# Module imported so that its registered components can be referenced by name.
custom_imports = {"imports": ["geospatial_fm"]}

# Multiplied by the number of bands to size the model input channels and the
# tensors produced by the data pipelines.
num_frames = 3
# NOTE(review): img_size, num_workers and num_layers are not referenced
# anywhere else in this snippet.
img_size = 224
num_workers = 2

# model
# TO BE DEFINED BY USER: model path
# pretrained_weights_path = "/teamspace/studios/this_studio/hls-foundation-os/prithvi/Prithvi_100M.pt"
num_layers = 6

# Training length and how often (in epochs) evaluation is run.
max_epochs = 80
eval_epoch_interval = 5
# Per-class weights passed to the loss as ``class_weight`` (13 entries).
loss_weights_multi = [
    0.386375, 0.661126, 0.548184, 0.640482,
    0.876862, 0.925186, 3.249462, 1.542289,
    2.175141, 2.272419, 3.062762, 3.626097,
    1.198702,
]

# Loss configuration shared by the decode head and the auxiliary head.
loss_func = {
    "type": "CrossEntropyLoss",
    "use_sigmoid": False,
    "class_weight": loss_weights_multi,
    "avg_non_ignore": True,
}
# TO BE DEFINED BY USER: Save directory
experiment = "classification"
project_dir = "/teamspace/studios/this_studio/project_classification_unet"
# Outputs go to <project_dir>/<experiment>; save_path is the same location.
work_dir = save_path = os.path.join(project_dir, experiment)

dataset_type = "GeospatialDataset"

# TO BE DEFINED BY USER: data directory
data_root = "/teamspace/studios/this_studio/data/multi-temporal-crop-classification/"

# Split list files per phase. "test" currently points at the validation list.
splits = {
    "train": "training_data.txt",
    "val": "validation_data.txt",
    "test": "validation_data.txt",
}
# Normalisation statistics: six values repeated three times, giving 18 entries
# per list (presumably one copy of the six bands per temporal frame - confirm).
img_norm_cfg = {
    "means": [
        494.905781,
        815.239594,
        924.335066,
        2968.881459,
        2634.621962,
        1739.579917,
    ] * 3,
    "stds": [
        284.925432,
        357.84876,
        575.566823,
        896.601013,
        951.900334,
        921.407808,
    ] * 3,
}
# Band indices; the length of this list feeds the channel-count computations.
bands = list(range(6))
# Side length used for the crop and for the reshape targets in the pipelines.
tile_size = 224
# NOTE(review): orig_nsize is not referenced anywhere else in this snippet.
orig_nsize = 512
# Square crop of tile_size x tile_size.
crop_size = (tile_size,) * 2
# Training data pipeline. Each entry names a transform by its registered
# ``type``; the transforms themselves are defined outside this file.
train_pipeline = [
    {"type": "LoadGeospatialImageFromFile", "to_float32": True},
    {"type": "LoadGeospatialAnnotations", "reduce_zero_label": True},
    {"type": "RandomFlip", "prob": 0.5},
    {"type": "ToTensor", "keys": ["img", "gt_semantic_seg"]},
    # presumably moves the channel axis first (HWC -> CHW) - confirm
    {"type": "TorchPermute", "keys": ["img"], "order": (2, 0, 1)},
    # means/stds come from img_norm_cfg
    {"type": "TorchNormalize", **img_norm_cfg},
    {"type": "TorchRandomCrop", "crop_size": crop_size},
    # image target shape: (bands * frames, tile, tile)
    {
        "type": "Reshape",
        "keys": ["img"],
        "new_shape": (len(bands) * num_frames, tile_size, tile_size),
    },
    # label target shape: (1, tile, tile)
    {
        "type": "Reshape",
        "keys": ["gt_semantic_seg"],
        "new_shape": (1, tile_size, tile_size),
    },
    {
        "type": "CastTensor",
        "keys": ["gt_semantic_seg"],
        "new_type": "torch.LongTensor",
    },
    {"type": "Collect", "keys": ["img", "gt_semantic_seg"]},
]
# Validation / test data pipeline.
#
# mmseg's ``EncoderDecoder.inference`` reads ``img_meta[0]['flip']`` (and
# ``'flip_direction'`` when flipping is on). The pipeline therefore has to
# record those entries and forward them in the collected metadata; without
# them validation stops with ``KeyError: 'flip'``.
test_pipeline = [
    dict(type="LoadGeospatialImageFromFile", to_float32=True),
    # prob=0.0 means the image is never flipped; this step is present only so
    # that "flip" (False) and "flip_direction" are written into the results.
    # NOTE(review): relies on mmseg 0.x RandomFlip behaviour - confirm against
    # the installed version.
    dict(type="RandomFlip", prob=0.0),
    dict(type="ToTensor", keys=["img"]),
    # presumably moves the channel axis first (HWC -> CHW) - confirm
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    # image target shape: (bands * frames, tile, tile)
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands) * num_frames, tile_size, tile_size),
    ),
    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
    dict(
        type="CollectTestList",
        keys=["img"],
        meta_keys=[
            "img_info",
            "seg_fields",
            "img_prefix",
            "seg_prefix",
            "filename",
            "ori_filename",
            "img",
            "img_shape",
            "ori_shape",
            "pad_shape",
            "scale_factor",
            # required by EncoderDecoder.inference
            "flip",
            "flip_direction",
            "img_norm_cfg",
        ],
    ),
]
# Class names in label order (13 classes). The tuple length is used as
# ``num_classes`` for the segmentation heads.
CLASSES = (
    "Natural Vegetation",    # 0
    "Forest",                # 1
    "Corn",                  # 2
    "Soybeans",              # 3
    "Wetlands",              # 4
    "Developed/Barren",      # 5
    "Open Water",            # 6
    "Winter Wheat",          # 7
    "Alfalfa",               # 8
    "Fallow/Idle Cropland",  # 9
    "Cotton",                # 10
    "Sorghum",               # 11
    "Other",                 # 12
)
# Registered dataset type used for every split.
dataset = "GeospatialDataset"

# Data loader / dataset configuration.
# Images and masks live in the same directory for each split and are told
# apart by suffix ("_merged.tif" vs ".mask.tif").
data = dict(
    samples_per_gpu=4,  # 8,
    workers_per_gpu=4,
    train=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir="training_chips",
        ann_dir="training_chips",
        pipeline=train_pipeline,
        img_suffix="_merged.tif",
        seg_map_suffix=".mask.tif",
        split=splits["train"],
    ),
    val=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir="validation_chips",
        ann_dir="validation_chips",
        pipeline=test_pipeline,
        img_suffix="_merged.tif",
        seg_map_suffix=".mask.tif",
        split=splits["val"],
    ),
    test=dict(
        type=dataset,
        CLASSES=CLASSES,
        reduce_zero_label=True,
        data_root=data_root,
        img_dir="validation_chips",
        ann_dir="validation_chips",
        pipeline=test_pipeline,
        img_suffix="_merged.tif",
        seg_map_suffix=".mask.tif",
        # Use the dedicated "test" entry so that editing splits["test"] takes
        # effect; it currently names the same file as splits["val"].
        split=splits["test"],
    ),
)
# Optimiser settings.
optimizer = {
    "type": "Adam",
    "lr": 1.5e-05,
    "betas": (0.9, 0.999),
    "weight_decay": 0.05,
}
# Gradient clipping is disabled.
optimizer_config = {"grad_clip": None}

# Polynomial schedule with a linear warm-up, stepped per iteration
# (by_epoch=False).
lr_config = {
    "policy": "poly",
    "warmup": "linear",
    "warmup_iters": 1500,
    "warmup_ratio": 1e-06,
    "power": 1.0,
    "min_lr": 0.0,
    "by_epoch": False,
}

# Log every 10 intervals through the text and TensorBoard hooks.
log_config = {
    "interval": 10,
    "hooks": [
        {"type": "TextLoggerHook"},
        {"type": "TensorboardLoggerHook"},
    ],
}
# Epoch-based checkpointing into save_path.
# NOTE(review): interval=100 is larger than max_epochs (80) - confirm that
# relying on ``save_best`` alone is intended.
checkpoint_config = {"by_epoch": True, "interval": 100, "out_dir": save_path}

# Evaluate every eval_epoch_interval epochs and keep the best-mIoU checkpoint.
evaluation = {
    "interval": eval_epoch_interval,
    "metric": "mIoU",
    "pre_eval": True,
    "save_best": "mIoU",
    "by_epoch": True,
}

# Training-set reduction switches; the factor is 1 and the flag is off.
reduce_train_set = {"reduce_train_set": False}
reduce_factor = {"reduce_factor": 1}

runner = {"type": "EpochBasedRunner", "max_epochs": max_epochs}
workflow = [("train", 1)]
# Normalisation layer config shared by the backbone and both heads.
norm_cfg = {"type": "BN", "requires_grad": True}

# from https://github.com/open-mmlab/mmsegmentation/issues/289
# UNet backbone with an FCN decode head and an FCN auxiliary head.
model = {
    "type": "EncoderDecoder",
    "pretrained": None,
    "backbone": {
        "type": "UNet",
        # one input channel per band per frame
        "in_channels": len(bands) * num_frames,
        "base_channels": 64,
        "num_stages": 5,
        "strides": (1, 1, 1, 1, 1),
        "enc_num_convs": (2, 2, 2, 2, 2),
        "dec_num_convs": (2, 2, 2, 2),
        "downsamples": (True, True, True, True),
        "enc_dilations": (1, 1, 1, 1, 1),
        "dec_dilations": (1, 1, 1, 1),
        "with_cp": False,
        "conv_cfg": None,
        "norm_cfg": norm_cfg,
        "act_cfg": {"type": "ReLU"},
        "upsample_cfg": {"type": "InterpConv"},
        "norm_eval": False,
    },
    # Main head: reads backbone output index 4 (64 channels).
    "decode_head": {
        "type": "FCNHead",
        "in_channels": 64,
        "in_index": 4,
        "channels": 64,
        "num_convs": 1,
        "concat_input": False,
        "dropout_ratio": 0.1,
        "num_classes": len(CLASSES),
        "norm_cfg": norm_cfg,
        "align_corners": False,
        "loss_decode": loss_func,
    },
    # Auxiliary head: reads backbone output index 3 (128 channels).
    "auxiliary_head": {
        "type": "FCNHead",
        "in_channels": 128,
        "in_index": 3,
        "channels": 64,
        "num_convs": 1,
        "concat_input": False,
        "dropout_ratio": 0.1,
        "num_classes": len(CLASSES),
        "norm_cfg": norm_cfg,
        "align_corners": False,
        "loss_decode": loss_func,
    },
    "train_cfg": {},
    # Sliding-window inference: tile-sized windows moved by half a tile.
    "test_cfg": {
        "mode": "slide",
        "stride": (tile_size // 2, tile_size // 2),
        "crop_size": (tile_size, tile_size),
    },
}

auto_resume = False
This trains for 5 epochs then fails on validation with:
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 253, in inference
flip = img_meta[0]['flip']
KeyError: 'flip'
I'm not sure how to proceed. This appears to be the same problem as https://github.com/open-mmlab/mmsegmentation/issues/231, although I did not change the augmentations.
Hi @robmarkcole, the UNet baseline is available here. @samKhallaghi can answer any questions you might have.
@robmarkcole I'm closing this in favor of the tickets you have opened on the UNet repo.
I wish to reproduce the Unet results from
Further information is requested, since mmsegmentation offers many UNet variants. The complete config used would be ideal.