microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Stitching: Access patch geo-transform in callback after predict #1407

Open adriantre opened 1 year ago

adriantre commented 1 year ago

Now that train.py supports both training and inference, I think we should close this PR. I think the best path forward is to modify predict_step in our trainers to store and stitch together predictions for each tile. Then we don't need PredictDataset or another script or anything else.

Originally posted by @adamjstewart in https://github.com/microsoft/torchgeo/issues/560#issuecomment-1522199639

Relevant issues and PRs: https://github.com/microsoft/torchgeo/issues/30, https://github.com/microsoft/torchgeo/issues/35, https://github.com/microsoft/torchgeo/pull/411

To stitch the prediction output on patches from the dataset we need to transform them to a common grid/reference system to be able to stitch them.

Let's say my predict_step's output is relative to the resolution of the sample patch (e.g. a seg mask, or boxes/points relative to the patch resolution and size). Then the transformation to the original (full) image could be performed using pyproj.Transformer for coords, or rasterio.windows.transform. These functions would also let us convert to any crs.
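As a rough, library-free sketch of the pixel-to-CRS step described above: for a north-up patch, the window transform reduces to an upper-left origin plus a pixel size. The function name `patch_pixels_to_crs` and the example coordinates are assumptions made for illustration; in practice `rasterio.transform.xy` applied to the window transform would play this role.

```python
from typing import List, Tuple


def patch_pixels_to_crs(
    points: List[Tuple[float, float]],
    west: float,
    north: float,
    xres: float,
    yres: float,
) -> List[Tuple[float, float]]:
    """Map (row, col) pixel indices in a north-up patch to (x, y) at pixel centers."""
    out = []
    for row, col in points:
        x = west + (col + 0.5) * xres  # columns grow eastward
        y = north - (row + 0.5) * yres  # rows grow southward
        out.append((x, y))
    return out


# Hypothetical 10 m patch whose upper-left corner is at (500000, 4000000).
coords = patch_pixels_to_crs([(0, 0), (2, 3)], 500000.0, 4000000.0, 10.0, 10.0)
print(coords)
```

Reprojecting the resulting coordinates to another CRS would be a separate step on top of this.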

But then we need to reference this sample's window and source image file.

My initial thought on where to implement this was in Lightning Callback on_predict_batch_end. It has access to the trainer, and thus the dataset. Then, if we had a way to access the patch/window_transform and the source file, we could do this.

To my understanding, GeoDataModule.transfer_batch_to_device removes the reference to the patch/window due to it not being a tensor. For regular datasets, we can pass the index through. Would we need something similar here?

Thoughts? Are there any utilities for RasterDataset that I have yet to find that can help solve this?

adamjstewart commented 1 year ago

This is definitely doable, and I think you're on the right track. I'll just add a couple things to the pot and let you keep stirring.

One thing that should be clarified is whether you are trying to stitch together patches/chips taken from a single large tile/scene, or if you also want to stitch together entire scenes into a larger extent (e.g., a single map for all of North America). @calebrob6 is in favor of saving the latter for GDAL and only focusing on the former within TorchGeo, which I think I agree with, especially for memory constraint reasons.

To stitch the prediction output on patches from the dataset we need to transform them to a common grid/reference system to be able to stitch them.

At the moment, all GeoDatasets store R-tree bounds and load images/masks in a common CRS. This may very well change in the future: #409. I think this part is actually relatively easy to implement and should help with #409 as well. Currently we only store CRS. We should also store transform and window like you said. The tricky part is figuring out how to pass it through the datamodule to the trainer. As you noticed, we currently prevent that in transfer_batch_to_device.

For the transform and window, it's actually relatively easy since they are just matrices. We could convert back and forth from rasterio to PyTorch Tensor. But the CRS is different. I would need to look into the internals of Lightning, but I wonder if there isn't some way to keep some keys in each batch on the CPU instead of transferring them. My first thought would be to try to modify our override to only transfer some keys and skip others. If that doesn't work, maybe open a discussion on the Lightning repo and ask for advice on how to do this.
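A minimal sketch of the "only transfer some keys" idea, assuming the batch is a plain dict and that anything with a `.to(device)` method should be moved. `transfer_selected` and `FakeTensor` are hypothetical names; `FakeTensor` only stands in for `torch.Tensor` so the sketch runs without PyTorch, and this has not been checked against how Lightning calls `transfer_batch_to_device` internally.

```python
from typing import Any, Dict, Tuple


def transfer_selected(
    batch: Dict[str, Any], device: str, skip: Tuple[str, ...] = ("crs", "bbox")
) -> Dict[str, Any]:
    """Move values that expose ``.to(device)`` unless their key is in ``skip``."""
    moved = {}
    for key, value in batch.items():
        if key not in skip and hasattr(value, "to"):
            moved[key] = value.to(device)
        else:
            moved[key] = value  # left untouched, e.g. kept on the CPU
    return moved


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


batch = {"image": FakeTensor(), "crs": "EPSG:32633", "bbox": [(0.0, 10.0, 0.0, 10.0)]}
result = transfer_selected(batch, "cuda:0")
print(result["image"].device, result["crs"])
```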

One other thought while we're on the topic. Ideally, we would be able to transfer all metadata from the file to the sample, including things like driver, dtype, nodata, compression, filename, etc. This gets even more complicated with RasterDatasets where each band is stored in a separate file. The idea is to make a prediction that replicates the input data as well as possible (same driver, dtype, etc.). We should also think about where we want to save predictions to (probably not the same location as the source data).

Excited to see progress along this direction!

adriantre commented 1 year ago

I have some ideas. We could pass everything we need through together with the batch by converting crs and window-transform (affine) to arrays, then back to affine matrix afterwards, as you say. For crs we store the epsg-code as scalar. Or we could store arrays for these in the dataset and do a lookup using some index that we pass with the batch instead.
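The lookup variant could look roughly like the sketch below: non-tensor metadata stays in a registry on the CPU, and only an integer id travels with the batch. `PatchMeta` and `MetaRegistry` are hypothetical names, not TorchGeo classes, and the CRS is kept as an opaque string rather than being reduced to a numeric code.

```python
from typing import Dict, List, NamedTuple, Tuple


class PatchMeta(NamedTuple):
    crs: str  # opaque CRS string, e.g. WKT or "EPSG:32633"
    transform: Tuple[float, float, float, float, float, float]  # affine a, b, c, d, e, f
    filename: str


class MetaRegistry:
    """Keeps non-tensor metadata on the CPU and hands out integer ids."""

    def __init__(self) -> None:
        self._items: List[PatchMeta] = []
        self._ids: Dict[PatchMeta, int] = {}

    def register(self, meta: PatchMeta) -> int:
        if meta not in self._ids:
            self._ids[meta] = len(self._items)
            self._items.append(meta)
        return self._ids[meta]

    def lookup(self, idx: int) -> PatchMeta:
        return self._items[idx]


registry = MetaRegistry()
meta = PatchMeta("EPSG:32633", (10.0, 0.0, 500000.0, 0.0, -10.0, 4000000.0), "scene.tif")
idx = registry.register(meta)  # this integer can travel with the batch as a tensor
print(idx, registry.lookup(idx).filename)
```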

I will start working on this when I get my computer back from a fire in our offices 😬

adamjstewart commented 1 year ago

For crs we store the epsg-code as scalar.

Not all CRS map to an EPSG code. Rasterio will attempt to find one that's close enough, but it won't be exact if we do this.

adriantre commented 1 year ago

Had another look at this and wrote down my thoughts.

Keeping it simple, I have a dataset consisting of one raster file product (only one crs). GridGeoSampler yields patches (windows) that are fed in batches to an object detector.

datamodule passes the RasterDataset query bbox with the batch, which I assume represents the window extent in the crs? (Given only one file)

predict_step returns pixel coordinates relative to each window.

on_predict_batch_end converts to crs coords and product pixel coords. This opens the raster-file to get the transform. Could probably be stored in the RasterDataset, but my understanding of how to handle multiple files is very limited.

on_predict_epoch_end stitches all predictions using either crs coords or product pixel coords. This might be NMS for object detection (removing duplicates due to overlapping windows). Then it saves the result to a file.

Are my assumptions correct? I would think that many real-life scenarios involve only one product during prediction. I do not know how this could be applied when multiple files and union-datasets are used.

import os
from typing import Any, Dict

import lightning.pytorch as pl
import torch
import rasterio
import rasterio.transform
import rasterio.windows
from lightning import Callback, Trainer
from lightning.pytorch.utilities.apply_func import move_data_to_device
from torch import Tensor
from torchgeo.datamodules import GeoDataModule

class CustomDataModule(GeoDataModule):
    def transfer_batch_to_device(
        self, batch: Dict[str, Tensor], device: torch.device, dataloader_idx: int
    ) -> Dict[str, Tensor]:
        # don't need this, unless we need to transform to a new crs
        del batch["crs"]

        # bbox is the extent of each sample / patch which,
        # when using a single image file,
        # can be used to create rasterio window transform
        batch["bbox"] = torch.as_tensor([list(bbox) for bbox in batch["bbox"]])

        return move_data_to_device(batch, device)

class PatchStitchingCallback(Callback):
    def __init__(self, output_dir):
        super().__init__()
        self.output_dir = output_dir

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        # This only works when single file in dataset
        raster_filepath = trainer.datamodule.predict_sampler.hits[0].object
        with rasterio.open(raster_filepath) as src:
            product_transform = src.transform

        for bbox, output in zip(batch["bbox"], outputs):
            # bbox is the window/patch extent in TorchGeo order
            # (minx, maxx, miny, maxy, mint, maxt); from_bounds expects
            # (left, bottom, right, top)
            minx, maxx, miny, maxy = bbox[:4].tolist()
            window = rasterio.windows.from_bounds(
                minx, miny, maxx, maxy, transform=product_transform
            )
            window_transform = rasterio.windows.transform(window, product_transform)

            # given that our detection model outputs pixel coordinates
            # tensor([[row1, col1],
            #         [row2, col2]])
            rows, cols = output["key_points"].T.cpu().numpy()
            xs, ys = rasterio.transform.xy(window_transform, rows, cols)
            product_rows, product_cols = rasterio.transform.rowcol(
                product_transform, xs, ys
            )
            output["key_points_product_crs"] = torch.stack(
                (torch.as_tensor(xs), torch.as_tensor(ys)), dim=1
            )
            output["key_points_product_pixels"] = torch.stack(
                (torch.as_tensor(product_rows), torch.as_tensor(product_cols)), dim=1
            )

    def write_on_epoch_end(self, predictions):
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))

    def on_predict_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
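        # one entry per batch, each holding the per-sample output dicts
        all_preds = trainer.predict_loop.predictions
        key_points = torch.cat(
            [out["key_points_product_pixels"] for outs in all_preds for out in outs]
        )

        # Merge all predictions (do_non_max_suppression is a user-supplied placeholder)
        predictions = do_non_max_suppression(key_points)
        self.write_on_epoch_end(predictions)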

# Usage: the model and datamodule are passed to predict(), not to Trainer()
trainer = Trainer(callbacks=[PatchStitchingCallback(output_dir="my_dir")])
trainer.predict(model=..., datamodule=CustomDataModule(...))

Took some inspiration from PyTorch Lightning PredictionWriter Callback.
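The `do_non_max_suppression` call in the callback is left undefined. One possible stand-in for key points, sketched here as an assumption rather than as the intended method, is a greedy distance-based suppression: keep the highest-scoring point and drop any later point that falls within `min_dist` of one already kept. The `(x, y, score)` layout and the function name are hypothetical.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float, float]  # (x, y, score)


def suppress_duplicate_points(points: List[Point], min_dist: float) -> List[Point]:
    """Greedy suppression: visit points by descending score, drop close neighbours."""
    kept: List[Point] = []
    for x, y, score in sorted(points, key=lambda p: p[2], reverse=True):
        if all(math.hypot(x - kx, y - ky) >= min_dist for kx, ky, _ in kept):
            kept.append((x, y, score))
    return kept


# Two detections of the same object from overlapping windows, plus one distinct object.
detections = [(100.0, 100.0, 0.9), (101.0, 100.5, 0.6), (250.0, 80.0, 0.8)]
merged = suppress_duplicate_points(detections, min_dist=5.0)
print(merged)
```

This is quadratic in the number of points, so a spatial index would be needed for large scenes.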

adamjstewart commented 1 year ago

Keeping it simple, I have a dataset consisting of one raster file product (only one crs).

For a more general implementation, we'll need to support multiple files.

datamodule passes the RasterDataset query bbox with the batch, which I assume represents the window extent in the crs?

Correct.

predict_step returns pixel coordinates relative to each window.

I assume this satisfies #35?

This opens the raster-file to get the transform. Could probably be stored in the RasterDataset, but my understanding of how to handle multiple files is very limited.

Each file in the RasterDataset gets its own entry in the R-tree. We can store arbitrary objects (or tuples of objects) in the R-tree. We currently only store filename, bbox, and CRS, but we'll likely need to add transform and possibly more metadata. The hard part will be figuring out how to get this to work in Lightning where the module wants to transfer everything to the GPU.

on_predict_epoch_end stitches all predictions using either crs coords or product pixel coords.

I think this will be a problem. For the average multi-file dataset, we'll quickly exceed memory if we wait til the end of each epoch to save predictions.

adriantre commented 1 year ago

I assume this satisfies #35?

I leave it up to the model to output relative to input. If additional transformations are applied after the sampler query (before reaching predict_step), then we need to reverse these transforms. Best would be to apply these transforms in the sampler such that the query is correct. This way we can trust that the bbox/query can be converted to a representative window transform.

Each file in the RasterDataset gets its own entry in the R-tree. We can store arbitrary objects (or tuples of objects) in the R-tree. We currently only store filename, bbox, and CRS, but we'll likely need to add transform and possibly more metadata. The hard part will be figuring out how to get this to work in Lightning where the module wants to transfer everything to the GPU.

I found that we don't need the product (file) transform in order to convert to RasterDataset.crs coordinates. We only need it to convert to product pixel coordinates. I suggest we start with conversion to RasterDataset.crs first, then expand. In the meantime, the user can use gdal to translate it to anything afterwards.

This is equivalent

window = rasterio.windows.from_bounds(*bbox[:4], product_transform)
window_transform = rasterio.windows.transform(window, product_transform)

with

patch_size = trainer.datamodule.patch_size
window_transform = rasterio.transform.from_bounds(*bbox[:4], patch_size, patch_size)

NB! The ordering of bbox/sampler-query is not the same as what from_bounds needs. Will fix.
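To make the ordering issue concrete, here is a small sketch without rasterio. TorchGeo's bounding box is ordered `(minx, maxx, miny, maxy, mint, maxt)`, while `from_bounds` takes `(west, south, east, north)`. The helper name `bbox_to_window_transform` is hypothetical; it returns the six affine coefficients that `rasterio.transform.from_bounds` would produce for a north-up patch.

```python
from typing import Sequence, Tuple


def bbox_to_window_transform(
    bbox: Sequence[float], width: int, height: int
) -> Tuple[float, float, float, float, float, float]:
    """Return affine coefficients (a, b, c, d, e, f) for a north-up patch."""
    minx, maxx, miny, maxy = bbox[:4]  # TorchGeo order; time bounds are ignored
    west, south, east, north = minx, miny, maxx, maxy  # order expected by from_bounds
    xres = (east - west) / width
    yres = (north - south) / height
    return (xres, 0.0, west, 0.0, -yres, north)


# Hypothetical 256 x 256 patch covering 2560 m x 2560 m.
coeffs = bbox_to_window_transform(
    (500000.0, 502560.0, 3997440.0, 4000000.0, 0.0, 1.0), 256, 256
)
print(coeffs)
```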

on_predict_epoch_end stitches all predictions using either crs coords or product pixel coords.

I think this will be a problem. For the average multi-file dataset, we'll quickly exceed memory if we wait til the end of each epoch to save predictions.

Seems like lightning keeps track of all predictions anyway, so memory may already be a problem in that case. But I found out that, on multi-gpu/node, on_predict_epoch_end only "has access" to the predictions on that node. So the output needs to be saved to files, then another process needs to perform the actual stitching anyway.
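A minimal sketch of that two-stage flow, assuming predictions can be serialized as JSON records: each rank writes one shard per batch, and a separate process later collects all shards. The file naming scheme and both function names are assumptions for illustration only.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def write_shard(output_dir: Path, rank: int, batch_idx: int, records: List[Dict]) -> Path:
    """Write one batch of predictions to its own file so nothing accumulates in memory."""
    path = output_dir / f"rank{rank}_batch{batch_idx:06d}.json"
    path.write_text(json.dumps(records))
    return path


def merge_shards(output_dir: Path) -> List[Dict]:
    """Run afterwards, in a separate process, to collect every rank's shards."""
    merged: List[Dict] = []
    for path in sorted(output_dir.glob("rank*_batch*.json")):
        merged.extend(json.loads(path.read_text()))
    return merged


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    write_shard(out, rank=0, batch_idx=0, records=[{"x": 500005.0, "y": 3999995.0}])
    write_shard(out, rank=1, batch_idx=0, records=[{"x": 500035.0, "y": 3999975.0}])
    all_records = merge_shards(out)
print(len(all_records))
```

The actual stitching (e.g. suppression of duplicates across window borders) would then run on the merged records.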

adamjstewart commented 1 year ago

So the output needs to be saved to files, then another process needs to perform the actual stitching anyway.

This is a good point, and would probably be sufficient for a first pass to solve any concerns about memory limitations. I would still love to ensure that GridGeoSampler distributes each file to the same device and flushes memory once a full file is complete, but this sounds like a lot more bookkeeping.
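The bookkeeping for flushing per file could be sketched as below, assuming the number of patches per file is known up front (for example, counted from the sampler) and that all patches of a file arrive on the same device. `PerFileBuffer` is a hypothetical helper, not an existing TorchGeo or Lightning class.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List


class PerFileBuffer:
    """Buffers patch predictions per source file and flushes once a file is complete."""

    def __init__(
        self, expected: Dict[str, int], flush: Callable[[str, List[Any]], None]
    ) -> None:
        self.expected = expected  # number of patches the sampler yields per file
        self.flush = flush
        self.pending: Dict[str, List[Any]] = defaultdict(list)

    def add(self, filename: str, prediction: Any) -> None:
        self.pending[filename].append(prediction)
        if len(self.pending[filename]) == self.expected[filename]:
            # Hand the complete file to the writer and drop it from memory.
            self.flush(filename, self.pending.pop(filename))


flushed: Dict[str, List[Any]] = {}
buffer = PerFileBuffer(
    {"a.tif": 2, "b.tif": 1}, lambda name, preds: flushed.update({name: preds})
)
buffer.add("a.tif", "patch0")
buffer.add("b.tif", "patch0")
buffer.add("a.tif", "patch1")
print(flushed, dict(buffer.pending))
```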

adriantre commented 1 year ago

Yes, this would indeed be good.

What do you think about the output reference?

  1. start with only supporting the RasterDataset common crs, or
  2. find a way to convert it to the crs of each sample's original file?

My guess is that most use-cases predict on only one file at a time, so I'm leaning towards 1. as the first step.

For data fusion use-cases this may not be the case. It also depends on how the end results will be used. Some users may want vector points in WGS84, which would be easy to implement. But for seg-masks, choosing a common crs may lead to incorrect masks? Some might want to run predict on e.g. a whole Sentinel-2 datatake. These can be up to 15k km long, so I'm struggling to decide what the crs/reference frame should be for multi-file prediction.

adamjstewart commented 1 year ago

I would support both GeoDataset and NonGeoDataset. The former can use the same CRS as the dataset (then we won't need access to the original file) while the latter can simply save as a PNG.

adriantre commented 1 year ago

For NonGeoDatasets, is there an equivalent to GridGeoSampler that is used? I presume these images are also processed in patches, and need to be stitched. In that case we would need the stride/overlap used in order to stitch it.

I think the proposed callback above would save the predictions and corresponding sample indices, and another process would collect and stitch them, similarly to GeoDatasets.
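For gridded patches without georeferencing, stitching only needs each patch's pixel offset, which follows from the stride. A rough sketch, assuming overlapping regions are simply averaged (other blending rules are possible) and using plain nested lists instead of tensors; `stitch_patches` is a hypothetical name:

```python
from typing import List, Tuple

Grid = List[List[float]]


def stitch_patches(
    patches: List[Tuple[int, int, Grid]], height: int, width: int
) -> Grid:
    """Average overlapping patches; each entry is (top_row, left_col, values)."""
    total = [[0.0] * width for _ in range(height)]
    count = [[0] * width for _ in range(height)]
    for top, left, values in patches:
        for r, row in enumerate(values):
            for c, value in enumerate(row):
                total[top + r][left + c] += value
                count[top + r][left + c] += 1
    return [
        [total[r][c] / count[r][c] if count[r][c] else 0.0 for c in range(width)]
        for r in range(height)
    ]


# Two 2x2 patches with stride 1 along the columns, so the middle column overlaps.
patches = [(0, 0, [[1.0, 1.0], [1.0, 1.0]]), (0, 1, [[3.0, 3.0], [3.0, 3.0]])]
canvas = stitch_patches(patches, height=2, width=3)
print(canvas)
```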

adamjstewart commented 1 year ago

The majority of NonGeoDatasets already consist of patches and do not require stitching. For tile-based NonGeoDatasets, we have a _RandomNCrop function used during training but haven't yet implemented a gridded version for prediction. So you can ignore stitching for now.