PatBall1 / detectree2

Python package for automatic tree crown delineation based on the Detectron2 implementation of Mask R-CNN
https://patball1.github.io/detectree2/
MIT License

Train/predict on multispectral (4+ band) imagery #147

Closed PatBall1 closed 1 month ago

PatBall1 commented 2 months ago

Introduce data readers that allow multispectral images to be used in training and prediction. The pre-trained base models used previously expect 3-channel input, so reusing them with 4+ band imagery is likely to be a problem.

See: https://github.com/facebookresearch/detectron2/issues/698 https://github.com/facebookresearch/detectron2/issues/1292 https://github.com/facebookresearch/detectron2/issues/2062

Also: https://detectron2.readthedocs.io/en/latest/tutorials/data_loading.html

PatBall1 commented 2 months ago

To modify the existing training routine to handle images with 4 or more bands, you need to make a few changes to the data loading and processing pipeline. Specifically, you need to ensure that the model can accept multi-band images and correctly process them during both training and inference.

Here’s how you can do this:

Step 1: Modify the DatasetMapper to Handle Multi-Band Images

You need to customize the DatasetMapper used in Detectron2 to handle multi-band images. This involves loading the images with all their bands and ensuring that they are passed correctly to the model.

  1. Custom DatasetMapper: Create a custom mapper that reads all the bands of the image and passes them to the model.
import numpy as np
import rasterio
import torch
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T


class MultiBandDatasetMapper:
    """Dataset mapper that loads every band of an image with rasterio."""

    def __init__(self, cfg, is_train=True, augmentations=None):
        self.is_train = is_train
        # An empty AugmentationList is a no-op that still returns a TransformList
        self.augmentations = T.AugmentationList(augmentations or [])

    def __call__(self, dataset_dict):
        dataset_dict = dataset_dict.copy()  # Shallow copy; the caller's dict is left untouched
        # utils.read_image(..., format="BGR") cannot return more than three channels,
        # so the image is loaded with the custom method instead
        image = self.load_all_bands(dataset_dict["file_name"])

        # Apply the augmentations and keep the transforms for the annotations
        aug_input = T.AugInput(image)
        transforms = self.augmentations(aug_input)
        image = aug_input.image

        # HWC -> CHW, the layout the model expects
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1)), dtype=torch.float32
        )

        # Annotations are absent at inference time
        if "annotations" in dataset_dict:
            annos = [
                utils.transform_instance_annotations(annotation, transforms, image.shape[:2])
                for annotation in dataset_dict.pop("annotations")
            ]
            dataset_dict["instances"] = utils.annotations_to_instances(annos, image.shape[:2])
        return dataset_dict

    def load_all_bands(self, image_path):
        """Load all bands of the image using rasterio and return an HWC float32 array."""
        with rasterio.open(image_path) as src:
            image = src.read()  # All bands, shape (bands, height, width)
        # Values keep their original scale; rescale here only if
        # cfg.MODEL.PIXEL_MEAN and cfg.MODEL.PIXEL_STD are set on the same scale
        image = image.astype(np.float32)
        # Transpose to HWC format
        return np.transpose(image, (1, 2, 0))
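
The mapper above moves the image through three array layouts: bands-first as read by rasterio, bands-last for the augmentations, and bands-first again for the model. As a minimal sketch, with a synthetic 4-band array standing in for a real raster (the array shape and the helper names `chw_to_hwc` and `hwc_to_chw` are illustrative assumptions, not part of detectree2 or Detectron2), the conversions look like this:

```python
import numpy as np


def chw_to_hwc(bands_first: np.ndarray) -> np.ndarray:
    """Convert (bands, height, width) to (height, width, bands)."""
    return np.transpose(bands_first, (1, 2, 0))


def hwc_to_chw(bands_last: np.ndarray) -> np.ndarray:
    """Convert (height, width, bands) to a contiguous (bands, height, width) array."""
    return np.ascontiguousarray(np.transpose(bands_last, (2, 0, 1)))


# Synthetic 4-band tile, 5 rows by 6 columns, standing in for src.read()
raster = np.arange(4 * 5 * 6, dtype=np.float32).reshape(4, 5, 6)

hwc = chw_to_hwc(raster)       # layout used by the augmentations
model_input = hwc_to_chw(hwc)  # layout passed to the model

print(hwc.shape)          # (5, 6, 4)
print(model_input.shape)  # (4, 5, 6)
```

The round trip leaves the pixel values unchanged; only the axis order differs.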

Step 2: Integrate the Custom DatasetMapper into the Training Pipeline

Now modify the trainer's build_train_loader method (a classmethod of MyTrainer) so that it uses this MultiBandDatasetMapper.

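import rasterio
from detectron2.data import DatasetCatalog, build_detection_train_loader
from detectron2.data import transforms as T


# Intended as a @classmethod of the trainer class (e.g. MyTrainer)
def build_train_loader(cls, cfg):
    """Build the training data loader for multi-band imagery.

    Args:
        cfg: Detectron2 config node; cfg.RESIZE selects the resizing strategy.

    Returns:
        The training data loader, using MultiBandDatasetMapper as its mapper.
    """
    # RandomSaturation and RandomLighting are left out: both assert a 3-channel image
    augmentations = [
        T.RandomBrightness(0.8, 1.8),
        T.RandomContrast(0.6, 1.3),
        # Rotation goes through cv2.warpAffine; verify it with your band count
        T.RandomRotation(angle=[90, 90], expand=False),
        T.RandomFlip(prob=0.4, horizontal=True, vertical=False),
        T.RandomFlip(prob=0.4, horizontal=False, vertical=True),
    ]

    # Test for "random" first: any non-empty string is truthy,
    # so a plain `if cfg.RESIZE:` would shadow this branch
    if cfg.RESIZE == "random":
        # Use the height of the first training tile as the reference size
        location = DatasetCatalog.get(cfg.DATASETS.TRAIN[0])[0]["file_name"]
        with rasterio.open(location) as src:  # cv2.imread may not read multi-band TIFFs
            size = src.height
        print("ADD RANDOM RESIZE WITH SIZE = ", size)
        augmentations.append(T.ResizeScale(0.6, 1.4, size, size))
    elif cfg.RESIZE:
        augmentations.append(T.Resize((1000, 1000)))

    return build_detection_train_loader(
        cfg,
        mapper=MultiBandDatasetMapper(
            cfg,
            is_train=True,
            augmentations=augmentations,
        ),
    )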
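
Colour-specific transforms such as RandomSaturation assume three channels, whereas flips and brightness scaling do not depend on the band count. The sketch below illustrates that idea in plain NumPy on a synthetic 5-band tile. It is not Detectron2's implementation, the function name `flip_and_scale` is hypothetical, and it leaves out the matching annotation transforms that Detectron2's own augmentations provide:

```python
import numpy as np


def flip_and_scale(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Randomly flip an HWC image and scale its brightness, for any number of bands."""
    if rng.random() < 0.4:
        image = image[:, ::-1, :]  # horizontal flip
    if rng.random() < 0.4:
        image = image[::-1, :, :]  # vertical flip
    factor = rng.uniform(0.8, 1.8)  # same range as RandomBrightness(0.8, 1.8)
    return np.ascontiguousarray(image * factor, dtype=np.float32)


rng = np.random.default_rng(0)
tile = rng.random((8, 8, 5)).astype(np.float32)  # synthetic 5-band tile
augmented = flip_and_scale(tile, rng)

print(augmented.shape)  # (8, 8, 5)
```

Every band receives the same flip and the same brightness factor, so the bands stay spatially aligned.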

Step 3: Ensure the Model Can Accept Multi-Band Images

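Detectron2 models expect 3-channel inputs by default (BGR with the default config). To work with multi-band images, you need to adjust the model’s input layer. Detectron2 takes the number of input channels from the length of cfg.MODEL.PIXEL_MEAN, so the remaining work is initialising the widened first convolution, which a 3-channel pre-trained checkpoint cannot fill. This requires a bit more customization: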

  1. Modify the Input Layer: You’ll need to modify the first layer of the model to accept more input channels. If you're using a ResNet backbone, this could look something like this:
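import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model

num_bands = 4  # Number of bands in your imagery

# Take the pre-trained 3-channel stem filters while cfg still has
# 3-entry PIXEL_MEAN / PIXEL_STD
rgb_model = build_model(cfg)
DetectionCheckpointer(rgb_model).load(cfg.MODEL.WEIGHTS)
old_weight = rgb_model.backbone.bottom_up.stem.conv1.weight.detach().clone()

# Update config to match the number of input channels;
# the backbone takes its input channel count from len(cfg.MODEL.PIXEL_MEAN)
cfg.INPUT.FORMAT = "BGR"  # Not read by the custom mapper, which loads the bands itself
cfg.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675] + [0.0] * (num_bands - 3)
cfg.MODEL.PIXEL_STD = [1.0] * num_bands

# Build the multi-band model; the 3-channel conv1 in the checkpoint no longer
# fits the widened layer, so that one parameter is not loaded
model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

# Initialise the widened conv1 in place (this keeps the layer's norm): the first
# three channels get the pre-trained filters, extra bands reuse them cyclically
with torch.no_grad():
    conv1 = model.backbone.bottom_up.stem.conv1
    band_index = [i % 3 for i in range(num_bands)]
    conv1.weight.copy_(old_weight[:, band_index, :, :])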
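
The PIXEL_MEAN and PIXEL_STD values above reuse the ImageNet BGR means and pad the extra bands with 0.0 and 1.0. Those placeholders will not necessarily match the band order or value range of a multispectral raster. One option is to estimate per-band statistics from the training tiles. The sketch below does this in NumPy with synthetic tiles; the helper name `band_statistics` is hypothetical, and the tiles are assumed to be HWC arrays with the same number of bands:

```python
import numpy as np


def band_statistics(tiles):
    """Return per-band mean and std over an iterable of HWC arrays."""
    count = 0
    total = None
    total_sq = None
    for tile in tiles:
        # Flatten the spatial axes: one row per pixel, one column per band
        pixels = tile.reshape(-1, tile.shape[-1]).astype(np.float64)
        if total is None:
            total = np.zeros(pixels.shape[1])
            total_sq = np.zeros(pixels.shape[1])
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        count += pixels.shape[0]
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2, clipped at zero against rounding error
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean.tolist(), std.tolist()


# Two synthetic 4-band tiles standing in for training imagery
rng = np.random.default_rng(0)
tiles = [rng.uniform(0, 255, size=(16, 16, 4)) for _ in range(2)]
pixel_mean, pixel_std = band_statistics(tiles)

print(len(pixel_mean), len(pixel_std))  # 4 4
```

The resulting lists have one entry per band and could be assigned to cfg.MODEL.PIXEL_MEAN and cfg.MODEL.PIXEL_STD, provided the mapper loads the bands on the same scale.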

Step 4: Update the Training Routine

Finally, ensure that your MyTrainer class and other parts of the codebase correctly integrate these changes. This includes using the modified model and data loader.

Full Example Integration

Here’s a simplified main function integrating the changes:

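from detectree2.models.train import MyTrainer, register_train_data, setup_cfg

if __name__ == "__main__":
    num_bands = 4  # Number of bands in your imagery

    train_location = "/path/to/train/dataset"
    register_train_data(train_location, "Paracou", 1)

    model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
    trains = ("Paracou_train",)
    tests = ("Paracou_val",)
    out_dir = "/path/to/output"

    cfg = setup_cfg(model, trains, tests, eval_period=100, max_iter=3000, out_dir=out_dir)

    # Adjust model for multi-band input: one PIXEL_MEAN / PIXEL_STD entry per band
    cfg.INPUT.FORMAT = "BGR"
    cfg.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675] + [0.0] * (num_bands - 3)
    cfg.MODEL.PIXEL_STD = [1.0] * num_bands

    # MyTrainer must use the build_train_loader from Step 2
    trainer = MyTrainer(cfg, patience=4)
    # The 3-channel stem weights in the checkpoint do not fit the widened conv1,
    # so that layer keeps its initial values unless it is set as in Step 3
    trainer.resume_or_load(resume=False)
    trainer.train()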

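Key Changes:

  1. A custom MultiBandDatasetMapper reads every band of each image with rasterio.
  2. build_train_loader passes that mapper to build_detection_train_loader.
  3. cfg.MODEL.PIXEL_MEAN and cfg.MODEL.PIXEL_STD get one entry per band, and the first convolution of the backbone is adapted to the new channel count.

With these changes, Detectron2 should be able to train on images with more than 3 channels, such as multispectral or hyperspectral images stored in TIFF format.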