Closed PatBall1 closed 1 month ago
To modify the existing training routine to handle images with 4 or more bands, you need to make a few changes to the data loading and processing pipeline. Specifically, you need to ensure that the model can accept multi-band images and correctly process them during both training and inference.
Here’s how you can do this:
Customize the DatasetMapper to Handle Multi-Band Images
You need to customize the DatasetMapper used in Detectron2 to handle multi-band images. This involves loading the images with all their bands and ensuring that they are passed correctly to the model.
import numpy as np
import rasterio
import torch
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T


class MultiBandDatasetMapper:
    """Dataset mapper that loads every band of an image instead of only BGR."""

    def __init__(self, cfg, is_train=True, augmentations=None):
        self.is_train = is_train
        # An empty AugmentationList is a no-op, so `transforms` is always defined.
        self.augmentations = T.AugmentationList(augmentations or [])

    def __call__(self, dataset_dict):
        dataset_dict = dataset_dict.copy()  # Make a copy of the dataset dict
        # utils.read_image(..., format="BGR") is not used here because it
        # converts the image to 3 channels; rasterio keeps every band.
        image = self.load_all_bands(dataset_dict["file_name"])

        # Apply the augmentations and keep the transforms for the annotations.
        aug_input = T.AugInput(image)
        transforms = self.augmentations(aug_input)
        image = aug_input.image

        # The model expects a float32 tensor in CHW order.
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1)).astype("float32")
        )

        if not self.is_train:
            # Annotations are not needed for inference.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        annos = [
            utils.transform_instance_annotations(annotation, transforms, image.shape[:2])
            for annotation in dataset_dict.pop("annotations")
        ]
        dataset_dict["instances"] = utils.annotations_to_instances(annos, image.shape[:2])
        return dataset_dict

    def load_all_bands(self, image_path):
        """Load all bands of the image using rasterio and return an HWC numpy array."""
        with rasterio.open(image_path) as src:
            image = src.read()  # Reads all bands as (bands, height, width)
        # Scale to [0, 1], assuming 8-bit data. cfg.MODEL.PIXEL_MEAN and
        # cfg.MODEL.PIXEL_STD must then be expressed on the same scale.
        image = image.astype(np.float32) / 255.0
        # Transpose to HWC format
        return np.transpose(image, (1, 2, 0))
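The axis handling in the mapper is easy to get wrong, so here is a minimal sketch of it in isolation, using only NumPy. The helper names `bands_first_to_hwc` and `hwc_to_chw` are hypothetical and not part of Detectron2 or rasterio; the sketch only assumes that the reader returns a `(bands, height, width)` array, as `src.read()` does in the mapper above.

```python
import numpy as np


def bands_first_to_hwc(bands_first):
    """Convert a (bands, height, width) array to (height, width, bands)."""
    return np.transpose(bands_first, (1, 2, 0))


def hwc_to_chw(image):
    """Convert an HWC array to a contiguous float32 CHW array for the model."""
    return np.ascontiguousarray(image.transpose(2, 0, 1)).astype("float32")


# A tiny synthetic 4-band raster standing in for a real multi-band TIFF.
raster = np.arange(4 * 2 * 3, dtype=np.uint16).reshape(4, 2, 3)
hwc = bands_first_to_hwc(raster)  # layout the augmentations work on
chw = hwc_to_chw(hwc)             # layout the model consumes
print(hwc.shape, chw.shape)
```

The round trip should leave the pixel values unchanged; only the axis order and dtype differ.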
Now, you need to modify the build_train_loader function to use this MultiBandDatasetMapper.
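import rasterio
from detectron2.data import DatasetCatalog, build_detection_train_loader
from detectron2.data import transforms as T


# Method of MyTrainer; @classmethod supplies the `cls` argument.
@classmethod
def build_train_loader(cls, cfg):
    """Build the training data loader using the multi-band mapper.

    Args:
        cfg (CfgNode): Detectron2 config, including the custom RESIZE key.

    Returns:
        The data loader returned by build_detection_train_loader.
    """
    augmentations = [
        T.RandomBrightness(0.8, 1.8),
        T.RandomContrast(0.6, 1.3),
        # T.RandomSaturation and T.RandomLighting are omitted: both assert
        # a 3-channel image, so they fail on input with 4 or more bands.
        T.RandomRotation(angle=[90, 90], expand=False),
        T.RandomFlip(prob=0.4, horizontal=True, vertical=False),
        T.RandomFlip(prob=0.4, horizontal=False, vertical=True),
    ]
    # Check "random" first: any non-empty string is truthy, so a plain
    # `if cfg.RESIZE:` test placed before it would shadow this branch.
    if cfg.RESIZE == "random":
        first = DatasetCatalog.get(cfg.DATASETS.TRAIN[0])[0]
        # Use rasterio for the height, since cv2.imread may not decode multi-band TIFFs.
        with rasterio.open(first["file_name"]) as src:
            size = src.height
        print("ADD RANDOM RESIZE WITH SIZE = ", size)
        augmentations.append(T.ResizeScale(0.6, 1.4, size, size))
    elif cfg.RESIZE:
        augmentations.append(T.Resize((1000, 1000)))
    return build_detection_train_loader(
        cfg,
        mapper=MultiBandDatasetMapper(
            cfg,
            is_train=True,
            augmentations=augmentations,
        ),
    )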
Detectron2 models expect 3-channel inputs by default (BGR unless cfg.INPUT.FORMAT says otherwise). To work with multi-band images, you need to adjust the model’s input layer. This requires a bit more customization:
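import torch
from detectron2.modeling import build_model

num_bands = 4  # Example value; set this to the number of bands in your images

# Update config to match the number of input channels
cfg.INPUT.FORMAT = "BGR"  # This can be kept as it is; the custom mapper loads the bands itself
cfg.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675] + [0.0] * (num_bands - 3)
cfg.MODEL.PIXEL_STD = [1.0, 1.0, 1.0] + [1.0] * (num_bands - 3)

# Build the model. Detectron2 sizes the backbone input from len(cfg.MODEL.PIXEL_MEAN),
# so the stem usually already accepts num_bands channels at this point.
model = build_model(cfg)

# If the first conv layer still has a different number of input channels, widen it
# in place. Editing the existing layer keeps the normalization attached to it.
conv1 = model.backbone.bottom_up.stem.conv1
old_channels = conv1.weight.shape[1]
if old_channels != num_bands:
    with torch.no_grad():
        # Reuse the existing filters cyclically for the extra bands.
        index = torch.arange(num_bands) % old_channels
        conv1.weight = torch.nn.Parameter(conv1.weight[:, index].clone())
    conv1.in_channels = num_bands
# Note: a 3-channel pretrained checkpoint will not match a num_bands-channel conv1
# weight, so that layer is typically skipped when the checkpoint is loaded.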
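To make the weight-widening step concrete, here is a minimal NumPy sketch of one possible initialization: reusing the 3-channel filters cyclically for the extra bands. This is only one heuristic (averaging the RGB filters or random initialization are others), and `expand_input_channels` is a hypothetical helper, not a Detectron2 function. The `(64, 3, 7, 7)` shape is an assumption standing in for a ResNet stem weight.

```python
import numpy as np


def expand_input_channels(weight, num_bands):
    """Return a copy of a conv weight with `num_bands` input channels.

    `weight` has shape (out_channels, in_channels, kH, kW). Existing input
    channels are reused cyclically, so band i gets the filter of channel
    i % in_channels.
    """
    in_channels = weight.shape[1]
    index = np.arange(num_bands) % in_channels
    return weight[:, index, :, :].copy()


# Random values standing in for a pretrained 3-channel stem weight.
rng = np.random.default_rng(0)
rgb_weight = rng.normal(size=(64, 3, 7, 7)).astype(np.float32)

expanded = expand_input_channels(rgb_weight, 8)
print(expanded.shape)
```

The first three channels are unchanged, so the pretrained filters still apply to the bands they were trained on, provided your band order matches the order used in pretraining.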
Finally, ensure that your MyTrainer class and other parts of the codebase correctly integrate these changes. This includes using the modified model and data loader.
Here’s a simplified main function integrating the changes:
if __name__ == "__main__":
    train_location = "/path/to/train/dataset"
    register_train_data(train_location, "Paracou", 1)
    model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
    trains = ("Paracou_train",)
    tests = ("Paracou_val",)
    out_dir = "/path/to/output"
    num_bands = 4  # Example value; set this to the number of bands in your images
    cfg = setup_cfg(model, trains, tests, eval_period=100, max_iter=3000, out_dir=out_dir)
    # Adjust model for multi-band input
    cfg.INPUT.FORMAT = "BGR"
    cfg.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675] + [0.0] * (num_bands - 3)
    cfg.MODEL.PIXEL_STD = [1.0, 1.0, 1.0] + [1.0] * (num_bands - 3)
    trainer = MyTrainer(cfg, patience=4)
    trainer.resume_or_load(resume=False)
    trainer.train()
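The 0.0 mean and 1.0 std used for the extra bands are placeholders. If you would rather normalize with statistics from your own data, something like the sketch below could compute them. It uses only NumPy; `band_statistics` is a hypothetical helper, and it assumes each tile is an HWC array with the same number of bands and the same value scale that the mapper feeds to the model.

```python
import numpy as np


def band_statistics(images):
    """Compute per-band mean and std over an iterable of HWC arrays."""
    total = None
    total_sq = None
    count = 0
    for image in images:
        # Flatten the spatial axes so each row is one pixel with all its bands.
        pixels = image.reshape(-1, image.shape[-1]).astype(np.float64)
        if total is None:
            total = np.zeros(pixels.shape[1])
            total_sq = np.zeros(pixels.shape[1])
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        count += pixels.shape[0]
    if count == 0:
        raise ValueError("band_statistics needs at least one non-empty image")
    mean = total / count
    # Clamp at zero to guard against tiny negative values from rounding.
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean.tolist(), std.tolist()


# Two tiny synthetic 4-band tiles standing in for real training images.
tiles = [np.zeros((2, 2, 4)), np.full((2, 2, 4), 2.0)]
pixel_mean, pixel_std = band_statistics(tiles)
print(pixel_mean, pixel_std)
```

The resulting lists have one entry per band, so they could be assigned to cfg.MODEL.PIXEL_MEAN and cfg.MODEL.PIXEL_STD. Bear in mind that pretrained weights were trained with the original normalization, so changing the first three entries may affect how well those weights transfer.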
This setup allows Detectron2 to train on images with more than 3 channels, such as multi-spectral or hyper-spectral images stored in TIFF format.
Introduce data readers that allow multispectral images to be used in training and prediction. Using pre-trained base models, as has been done previously, may be a problem with multispectral input.
See: https://github.com/facebookresearch/detectron2/issues/698 https://github.com/facebookresearch/detectron2/issues/1292 https://github.com/facebookresearch/detectron2/issues/2062
Also: https://detectron2.readthedocs.io/en/latest/tutorials/data_loading.html