microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License
2.62k stars 322 forks source link

ArcGIS Pro support #451

Open tritolol opened 2 years ago

tritolol commented 2 years ago

Hi, I'm currently working on an interface between ESRI ArcGIS Pro and torchgeo. The basic functionality is done so far, I'm able to sample raster layers directly off a running ArcGIS application using torchgeo samplers. Technically, this is just a new dataset class inheriting from GeoDataset, currently only about 50 lines of code. At the moment I'm planning to use this only for inference purposes, since training within an ArcGIS Pro instance isn't feasible imo.

The motivation behind writing this interface is to provide city administration staff who work with this software with a simple way of using arbitrary pretrained PyTorch models. Currently, I'm using a Python toolbox script that lets the user define input and output layers; it contains just a few lines of code, which is a nice thing.

I'm aware of the Deep Learning toolbox in ESRI's Image Analyst package. However, I didn't find an easy way to incorporate custom trained models into this workflow.

Would you be interested in including this into torchgeo?

isaaccorley commented 2 years ago

Could you provide a gist of the implementation?

adamjstewart commented 2 years ago

Would you be interested in including this into torchgeo?

Yes!

tritolol commented 2 years ago

Could you provide a gist of the implementation?

Sure arcgisdatasets.py:

import arcpy
from torchgeo.datasets.geo import GeoDataset
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast
from torchgeo.datasets.utils import BoundingBox
from rasterio.crs import CRS
import torch
import torchvision
import sys

class ArcGisRasterDataset(GeoDataset):
    """GeoDataset backed by a raster layer of a running ArcGIS Pro session.

    Extent, cell size and spatial reference are read through ``arcpy`` when
    the dataset is constructed; pixel data is fetched per query in
    ``__getitem__`` with ``arcpy.RasterToNumPyArray``.
    """

    layer_name = ""

    def __init__(
        self,
        layer_name,
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        """Query the layer's georeferencing and register its extent in the index.

        Args:
            layer_name: raster layer identifier, passed unchanged to arcpy
            crs: currently unused; the CRS is always derived from the layer's
                spatial reference (kept for signature compatibility)
            res: resolution of the returned samples; defaults to the layer's
                own cell size (``CELLSIZEX``)
            transforms: callable applied to every sample in ``__getitem__``
        """
        super().__init__(transforms)

        self.layer_name = layer_name

        top = self._raster_property("TOP")
        bottom = self._raster_property("BOTTOM")
        left = self._raster_property("LEFT")
        right = self._raster_property("RIGHT")

        # The layer carries no time information here, so the temporal range
        # of the index entry is left open: 0 .. sys.maxsize.
        coords = (left, right, bottom, top, 0, sys.maxsize)
        self.extent = coords
        self.index.insert(0, coords)

        self.source_res = self._raster_property("CELLSIZEX")
        self.res = self.source_res if res is None else res

        spatial_ref = arcpy.Describe(self.layer_name).spatialReference.factoryCode
        self._crs = CRS.from_epsg(spatial_ref)

    def _raster_property(self, property_type: str) -> float:
        """Return one numeric raster property of the layer as a float."""
        result = arcpy.GetRasterProperties_management(self.layer_name, property_type)
        # The value arrives as text and may contain a decimal comma
        # (presumably locale-dependent formatting); normalise it for float().
        return float(result.getOutput(0).replace(",", "."))

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Read the pixels covered by ``query`` and return them as a sample.

        Args:
            query: bounding box whose minx/maxx/miny/maxy select the window

        Returns:
            dict with the pixel tensor under ``"image"``, the dataset CRS under
            ``"crs"`` and the query itself under ``"bbox"``
        """
        width = query.maxx - query.minx
        height = query.maxy - query.miny

        lower_left = arcpy.Point(query.minx, query.miny)
        # RasterToNumPyArray takes the window size as (ncols, nrows).
        ncols = round(width / self.source_res)
        nrows = round(height / self.source_res)
        data = torch.tensor(
            arcpy.RasterToNumPyArray(self.layer_name, lower_left, ncols, nrows)
        )

        sample = {"image": data, "crs": self.crs, "bbox": query}

        # Transforms were accepted by the constructor but never applied; run
        # them before resampling so that e.g. no-data replacement happens on
        # the original pixels.
        if self.transforms is not None:
            sample = self.transforms(sample)

        if self.res != self.source_res:
            # Resize expects (height, width), i.e. (rows, cols) -- the
            # opposite order of the arcpy window size above.
            # NOTE(review): a single-band layer presumably yields a 2-D array
            # here -- confirm Resize accepts that shape.
            target_size = (round(height / self.res), round(width / self.res))
            resize = torchvision.transforms.Resize(target_size)
            sample["image"] = resize(sample["image"])

        return sample

testtool2.pyt (ArcGIS Python Toolbox script):

#from testtool2 import *
# -*- coding: utf-8 -*-

from torchgeo.samplers import GridGeoSampler
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from arcgisdatasets import ArcGisRasterDataset
import torch
import itertools
import arcpy

class Toolbox:
    """Python toolbox definition; the toolbox is named after its .pyt file."""

    def __init__(self):
        """Set the toolbox label and alias and register its tool classes."""
        # Tool classes that belong to this toolbox.
        self.tools = [Tool]

        self.alias = "toolbox"
        self.label = "Toolbox"

class Tool(object):
    """Geoprocessing tool that tiles two raster layers and writes each batch
    back to disk as a mosaic (the inference step is still a placeholder)."""

    def __init__(self):
        """Define the tool (tool name is the name of the class)."""
        self.label = "Tool"
        self.description = ""
        self.canRunInBackground = False

    def getParameterInfo(self):
        """Define parameter definitions.

        Returns:
            list of three required parameters, in order: input raster A,
            input raster B, output features
        """
        # First parameter
        param0 = arcpy.Parameter(
            displayName="Input Raster A",
            name="in_raster_a",
            datatype="GPRasterLayer",
            parameterType="Required",
            direction="Input")

        # Second parameter
        param1 = arcpy.Parameter(
            displayName="Input Raster B",
            name="in_raster_b",
            datatype="GPRasterLayer",
            parameterType="Required",
            direction="Input")

        # Third parameter
        param2 = arcpy.Parameter(
            displayName="Output Features",
            name="out_features",
            datatype="GPFeatureLayer",
            parameterType="Required",
            direction="Output")

        return [param0, param1, param2]

    def isLicensed(self):
        """Set whether tool is licensed to execute."""
        return True

    def updateParameters(self, parameters):
        """Modify the values and properties of parameters before internal
        validation is performed.  This method is called whenever a parameter
        has been changed."""
        return

    def updateMessages(self, parameters):
        """Modify the messages created by internal validation for each tool
        parameter.  This method is called after internal validation."""
        return

    def execute(self, parameters, messages):
        """Sample both rasters on a grid and write one mosaic per batch.

        Args:
            parameters: tool parameters in the order of getParameterInfo()
            messages: geoprocessing messages object (unused)
        """
        batch_size = 8

        raster_a = ArcGisRasterDataset(parameters[0].valueAsText, res=0.2)
        raster_b = ArcGisRasterDataset(parameters[1].valueAsText, res=0.2)

        dataset = raster_a & raster_b

        gs = GridGeoSampler(dataset, 1024, 1024)

        # num_workers > 0 not working
        dataloader = DataLoader(
            dataset,
            sampler=gs,
            collate_fn=stack_samples,
            batch_size=batch_size,
            num_workers=0,
        )

        arcpy.env.compression = "LZW"

        # Does not change between batches, so look it up once.
        spatial_reference = arcpy.Describe(parameters[0].valueAsText).spatialReference

        for batch_index, batch in enumerate(dataloader):
            batch_np = batch["image"].numpy()

            # ToDo: Inference

            # Convert input batch back to arcgis rasters. Iterate over the
            # samples actually present: the last batch can hold fewer than
            # batch_size tiles.
            batch_rasters = [
                arcpy.NumPyArrayToRaster(
                    tile,
                    lower_left_corner=arcpy.Point(bbox.minx, bbox.miny),
                    x_cell_size=dataset.res,
                )
                for tile, bbox in zip(batch_np, batch["bbox"])
            ]

            # Assemble batch rasters to a single raster and write it to the
            # output location. The file is named after the batch index; a
            # comprehension variable is not visible here in Python 3.
            arcpy.management.MosaicToNewRaster(
                batch_rasters,
                parameters[2].valueAsText,
                "%d.tif" % batch_index,
                coordinate_system_for_the_raster=spatial_reference,
                number_of_bands=batch_np.shape[1],
            )

        return
joshuasundance-swca commented 1 year ago

I would be interested in collaborating on any interface between ArcGIS Pro and torchgeo. Has any more work been done on this or related subjects?

adamjstewart commented 1 year ago

I think the next step is figuring out how all of this fits into TorchGeo. For arcgisdatasets.py it's obvious where to put it. testtool2.pyt is a little less obvious to me. Is this something that should be installed when you run pip install torchgeo? Or is this just an example of usage of the tool? If the latter, we'll definitely want to write a Jupyter Notebook tutorial to add to the docs to show how to use it.

I'm unfortunately ignorant of the inner workings of ArcGIS (I mostly use QGIS but have never written tools for either) so I'm relying on y'all to know what's needed and how it should work. I think I actually have a license for ArcGIS through the university, so I'll be the perfect guinea pig to test the effectiveness of the tutorial we add.

Bonus points if anyone out there is familiar with QGIS and wants to write something similar for that.

joshuasundance-swca commented 1 year ago

testtool2.pyt is designed to be interpreted by ArcGIS Pro. This file is used to embed a nice little GUI for the custom tool in ArcGIS Pro. This code does not necessarily need to be packaged with torchgeo, but it should be made available to torchgeo users who also use ArcGIS Pro. Maybe as an optional pip install torchgeo[arcgis] extra, but the user will need to import testtool2.pyt into their ArcGIS project or configure their environment to import it by default.

I see two opportunities for interoperability:

  1. Using models trained in torchgeo to perform inferencing in ArcGIS Pro
  2. Using models trained in ArcGIS Pro to perform inferencing in torchgeo

Personally, I would be excited about both, but mainly # 2. The mechanics and logistics for both would be found in the links and wikis in the following repos:

For guidance, we could look to the output of model training in ArcGIS Pro. Generally speaking the process yields the following:

  1. the trained model as a pth or pb file (maybe more? I don't know)
  2. an esri model definition (json, *.emd) file: configuration and documentation; framework info, expected cell size, stuff like that
  3. a model configuration file (python): class with predefined structure that allows external training? never used it directly
  4. inferencing raster function (python): shows how the image is indexed, iterated over, preprocessed, inferenced on, etc
  5. a dlpk (deep learning package) file, which seems to be a zip file containing the items listed above

The raster function uses the updatePixels method on the pixelBlocks object.

ArcGIS Pro and torchgeo both have strengths and weaknesses, and I think that having the best of both worlds would be great for research and productivity. ArcGIS Pro is great for interactively creating training data, and it is good for training models even with multiple GPUs, but inferencing at scale can be a struggle. I think it would be great, for example, to take a model trained in ArcGIS Pro and then deploy it for inferencing using torchgeo. This would also make the models more portable and useful across platforms, as ArcGIS Pro cannot be used in Linux environments. I have often thought about working on bridging the two, but I think it may be challenging. Any thoughts, guidance, or discussion would be appreciated.

adamjstewart commented 1 year ago

ArcGIS Pro ... is good for training models even with multiple GPUs

For the record, TorchGeo can also do this using PyTorch Lightning. Lightning supports distributed training on multiple nodes, or on multiple GPUs on the same node. And training metrics can be automatically logged to TensorBoard. But I agree that the interactivity of ArcGIS is nice.

This all sounds great, and thanks for the documentation links! I think there's definitely a lot of possibilities for integration here. Let's start small and get the basic, most commonly used things working, and expand from there. The dataset defined above can be added to torchgeo/datasets. For the .pyt file, where does ArcGIS search for these files? Is there some environment variable that needs to be set? If it's small enough we can also just store it in the documentation. For most other things, creating a new torchgeo/utils directory may make sense.

@tritolol would you like to submit a PR to add your preliminary code to TorchGeo? We can iterate from there.

tritolol commented 1 year ago

It's been a while now since I worked on this. However, I managed to get a torchgeo-based inference loop running eventually. The development of the toolbox was quite painful since I didn't find a convenient way to debug. I think I had to use Visual Studio (without "Code") to get breakpoints to work and had to reload the toolbox inside ArcGIS Pro every time I made a change. This seems to be an issue for others as well: https://community.esri.com/t5/arcgis-pro-questions/debugging-a-python-toolbox-using-visual-studio/td-p/1079955 As far as I know, the toolbox script must be in the folder of the ArcGIS project you want to use it in, but there might be other options.

Before continuing integration we should find a more convenient way to develop toolboxes. Also, I don't really have the time at the moment to pursue this further. Anyway, this is the final state of testtool2.pyt:

#from testtool2 import *
# -*- coding: utf-8 -*-

from torchgeo.datasets.geo import GeoDataset
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast
from torchgeo.datasets.utils import BoundingBox
from rasterio.crs import CRS
import torch
import torchvision
import sys

from torch.nn.modules import Module
from torch import Tensor

from torchgeo.samplers import GridGeoSampler
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
import torch
import itertools
import arcpy
import os

import segmentation_models_pytorch as smp

class NoDataReplacer(Module):
    """Transform that overwrites no-data pixels of a sample's image in place."""

    def __init__(self, nodata_value: float, replace_with: float) -> None:
        """Initialize a new transform instance.

        Args:
            nodata_value: pixel value that marks missing data
            replace_with: value written over every pixel equal to nodata_value
        """
        super().__init__()
        self.replace = replace_with
        self.nodat = nodata_value

    def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Replace no-data pixels in ``sample["image"]``.

        Args:
            sample: dict holding the tensor to clean under the key "image"

        Returns:
            the same dict object; its image tensor is modified in place
        """
        image = sample["image"]
        is_nodata = image == self.nodat
        image[is_nodata] = self.replace

        return sample

class ArcGisRasterDataset(GeoDataset):
    """GeoDataset backed by a raster layer of a running ArcGIS Pro session.

    Extent, cell size, band count and spatial reference are read through
    ``arcpy`` when the dataset is constructed; pixel data is fetched per query
    in ``__getitem__`` with ``arcpy.RasterToNumPyArray``.
    """

    layer_name = ""

    def __init__(
        self,
        layer_name,
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        """Query the layer's georeferencing and register its extent in the index.

        Args:
            layer_name: raster layer identifier, passed unchanged to arcpy
            crs: currently unused; the CRS is always derived from the layer's
                spatial reference (kept for signature compatibility)
            res: resolution of the returned samples; defaults to the layer's
                own cell size (``CELLSIZEX``)
            transforms: callable applied to every sample before resampling
        """
        super().__init__(transforms)

        self.layer_name = layer_name

        top = self._raster_property("TOP")
        bottom = self._raster_property("BOTTOM")
        left = self._raster_property("LEFT")
        right = self._raster_property("RIGHT")

        # The layer carries no time information here, so the temporal range
        # of the index entry is left open: 0 .. sys.maxsize.
        coords = (left, right, bottom, top, 0, sys.maxsize)
        self.extent = coords
        self.index.insert(0, coords)

        self.source_res = self._raster_property("CELLSIZEX")
        self.res = self.source_res if res is None else res

        self.bands = int(
            arcpy.GetRasterProperties_management(self.layer_name, "BANDCOUNT").getOutput(0)
        )

        description = arcpy.Describe(self.layer_name)
        self._crs = CRS.from_epsg(description.spatialReference.factoryCode)

        self.transforms = transforms

    def _raster_property(self, property_type: str) -> float:
        """Return one numeric raster property of the layer as a float."""
        result = arcpy.GetRasterProperties_management(self.layer_name, property_type)
        # The value arrives as text and may contain a decimal comma
        # (presumably locale-dependent formatting); normalise it for float().
        return float(result.getOutput(0).replace(",", "."))

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Read the pixels covered by ``query`` and return them as a sample.

        Args:
            query: bounding box whose minx/maxx/miny/maxy select the window

        Returns:
            dict with the pixel tensor under ``"image"``, the dataset CRS under
            ``"crs"`` and the query under ``"bbox"``, after transforms and
            optional resampling to ``self.res``
        """
        width = query.maxx - query.minx
        height = query.maxy - query.miny

        lower_left = arcpy.Point(query.minx, query.miny)
        # RasterToNumPyArray takes the window size as (ncols, nrows).
        ncols = round(width / self.source_res)
        nrows = round(height / self.source_res)
        data = torch.tensor(
            arcpy.RasterToNumPyArray(self.layer_name, lower_left, ncols, nrows)
        )
        if self.bands == 1:
            # Add the leading band axis for single-band layers.
            data = data.unsqueeze(0)

        sample = {"image": data, "crs": self.crs, "bbox": query}

        if self.transforms is not None:
            sample = self.transforms(sample)

        if self.res != self.source_res:
            # Resize expects (height, width), i.e. (rows, cols) -- the
            # opposite order of the arcpy window size above.
            target_size = (round(height / self.res), round(width / self.res))
            resize = torchvision.transforms.Resize(target_size)
            sample["image"] = resize(sample["image"])

        # Return the transformed sample itself; rebuilding it from the raw
        # array would discard the transform output when no resampling occurs.
        return sample

class Toolbox:
    """Python toolbox definition; the toolbox is named after its .pyt file."""

    def __init__(self):
        """Set the toolbox label and alias and register its tool classes."""
        # Tool classes that belong to this toolbox.
        self.tools = [Tool]

        self.alias = "toolbox"
        self.label = "Toolbox"

class Tool(object):
    """Geoprocessing tool that runs a segmentation model over raster tiles.

    Input is either a single differential raster ("tDOM") or two rasters
    ("DOM A, DOM B") whose band difference is computed before inference.
    Predictions are mosaicked and written as GeoTIFF tiles to an output folder.
    """

    # Resolution requested from the datasets and the coarsest source cell
    # size that is accepted.
    _TARGET_RES = 0.2
    # size and stride arguments passed to GridGeoSampler.
    _TILE_SIZE = 1024
    # Entries of the "Input data type" value list.
    _MODE_TDOM = "tDOM"
    _MODE_DOM_PAIR = "DOM A, DOM B"

    def __init__(self):
        """Define the tool (tool name is the name of the class)."""
        self.label = "Tool"
        self.description = ""
        self.canRunInBackground = False

    def getParameterInfo(self):
        """Define parameter definitions.

        Returns:
            list of parameters in dialog order: input data type, input raster
            A, input raster B, output folder, model file, batch size, tile
            aggregation size
        """
        # First parameter
        param0 = arcpy.Parameter(
            displayName="Input Raster A",
            name="in_raster_a",
            datatype="GPRasterLayer",
            parameterType="Required",
            direction="Input")

        # Second parameter
        param1 = arcpy.Parameter(
            displayName="Input Raster B",
            name="in_raster_b",
            datatype="GPRasterLayer",
            parameterType="Optional",
            direction="Input")

        # Third parameter
        param2 = arcpy.Parameter(
            displayName="Output Tiles",
            name="out_tiles",
            datatype="DEFolder",
            parameterType="Required",
            direction="Output")

        param3 = arcpy.Parameter(
            displayName="Input data type",
            name="in_value",
            datatype="GPString",
            parameterType="Required",
            direction="Input")

        param4 = arcpy.Parameter(
            displayName="Model",
            name="in_model",
            datatype="DEFile",
            parameterType="Required",
            direction="Input")

        param5 = arcpy.Parameter(
            displayName="Batch Size",
            name="in_bs",
            datatype="GPLong",
            parameterType="Required",
            direction="Input")

        param6 = arcpy.Parameter(
            displayName="Tile Aggregation Size",
            name="in_as",
            datatype="GPLong",
            parameterType="Required",
            direction="Input")

        # Restrict the input data type to the two supported modes.
        param3.filter.type = "ValueList"
        param3.filter.list = ["DOM A, DOM B", "tDOM"]

        # Only offer files with the .pt extension as model.
        param4.filter.list = ['pt']

        # Batch size: 1..16.
        param5.filter.type = "ValueList"
        param5.filter.list = list(range(1, 17))

        # Number of batches mosaicked into one output tile: 1..100.
        param6.filter.type = "ValueList"
        param6.filter.list = list(range(1, 101))

        # The data type comes first so it can drive which rasters are enabled.
        return [param3, param0, param1, param2, param4, param5, param6]

    def isLicensed(self):
        """Set whether tool is licensed to execute."""
        return True

    def updateParameters(self, parameters):
        """Modify the values and properties of parameters before internal
        validation is performed.  This method is called whenever a parameter
        has been changed.

        Enables raster A only once a data type is chosen and raster B only
        for the two-raster mode.
        """
        mode = parameters[0].value

        if mode == "tDOM":
            parameters[1].enabled = True
            parameters[2].enabled = False
        elif mode == "DOM A, DOM B":
            parameters[1].enabled = True
            parameters[2].enabled = True
        else:
            parameters[1].enabled = False
            parameters[2].enabled = False

        return

    def updateMessages(self, parameters):
        """Modify the messages created by internal validation for each tool
        parameter.  This method is called after internal validation."""
        return

    def _open_raster(self, layer_name):
        """Open a layer as a dataset with no-data replaced by 0 and check its resolution."""
        description = arcpy.Describe(layer_name)
        nodat_repl = NoDataReplacer(description.noDataValue, 0.0)
        raster = ArcGisRasterDataset(layer_name, res=self._TARGET_RES, transforms=nodat_repl)
        # Round before comparing so that floating-point noise in the reported
        # cell size does not reject a raster that is nominally at the limit.
        if round(raster.source_res, 3) > self._TARGET_RES:
            raise ValueError('A source dataset\'s resolution is too low. Minimum ground sampling distance: 0.2m')
        return raster

    def _write_tile(self, rasters, output_dir, tile_index, spatial_reference):
        """Mosaic the collected rasters into ``<tile_index>.tif`` in the output folder."""
        arcpy.SetProgressorLabel("Writing tile to disk...")
        arcpy.management.MosaicToNewRaster(
            rasters,
            output_dir,
            "%d.tif" % tile_index,
            coordinate_system_for_the_raster=spatial_reference,
            number_of_bands=1,
            pixel_type="32_BIT_FLOAT",
        )

    def execute(self, parameters, messages):
        """Run the model over the input raster(s) and write the result tiles.

        Args:
            parameters: tool parameters in the order of getParameterInfo()
            messages: geoprocessing messages object (unused)

        Raises:
            ValueError: if a source raster is coarser than the minimum
                resolution or the input data type is not a supported mode
        """
        mode = parameters[0].valueAsText
        primary_layer = parameters[1].valueAsText
        output_dir = parameters[3].valueAsText
        batch_size = parameters[5].value
        aggregation_size = parameters[6].value

        if mode == self._MODE_TDOM:
            dataset = self._open_raster(primary_layer)
        elif mode == self._MODE_DOM_PAIR:
            raster_a = self._open_raster(primary_layer)
            raster_b = self._open_raster(parameters[2].valueAsText)
            dataset = raster_a & raster_b
        else:
            # Fail with a clear message instead of a NameError further down.
            raise ValueError("Unsupported input data type: %r" % mode)

        gs = GridGeoSampler(dataset, self._TILE_SIZE, self._TILE_SIZE)

        dataloader = DataLoader(
            dataset,
            sampler=gs,
            collate_fn=stack_samples,
            batch_size=batch_size,
            num_workers=0,
        )

        arcpy.env.compression = "LZW"

        model_name = parameters[4].valueAsText
        model = smp.Unet(encoder_name="resnet101", encoder_weights="imagenet", in_channels=1).cuda()

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_name)
        # Drop "module." from the parameter names so they match the model
        # built above (presumably left over from a DataParallel-style wrapper).
        fixed_state_dict = {
            key.replace("module.", ""): value
            for key, value in checkpoint['model_state_dict'].items()
        }
        model.load_state_dict(fixed_state_dict)

        model.eval()

        # Raises if the folder already exists, so earlier results are never
        # silently mixed with new ones.
        os.mkdir(output_dir)

        arcpy.AddMessage("Processing %d batches..." % len(dataloader))
        arcpy.SetProgressor("step", "Processing %d batches..." % len(dataloader), 0, len(dataloader), 1)

        # Does not change between batches, so look it up once.
        spatial_reference = arcpy.Describe(primary_layer).spatialReference

        batch_rasters = []
        tile_index = 0
        for i, batch in enumerate(dataloader):
            arcpy.SetProgressorPosition(i + 1)
            arcpy.SetProgressorLabel("Processing batch: {}".format(i + 1))
            batch_data = batch["image"]
            if mode == self._MODE_DOM_PAIR:
                # calculate tDOM: band 1 minus band 0, kept as one channel
                batch_data = (batch_data[:, 1, :, :] - batch_data[:, 0, :, :]).unsqueeze(1)

            with torch.no_grad():
                y = model(batch_data.cuda())

            y = torch.nn.functional.relu(y).cpu().numpy()

            # Convert the predictions back to arcgis rasters. Iterate over the
            # samples actually present: the last batch can hold fewer than
            # batch_size tiles.
            batch_rasters += [
                arcpy.NumPyArrayToRaster(
                    prediction,
                    lower_left_corner=arcpy.Point(bbox.minx, bbox.miny),
                    x_cell_size=dataset.res,
                    value_to_nodata=0,
                )
                for prediction, bbox in zip(y, batch["bbox"])
            ]

            # Every aggregation_size batches, assemble the collected rasters
            # into a single raster and write it to the results folder.
            if (i + 1) % aggregation_size == 0:
                self._write_tile(batch_rasters, output_dir, tile_index, spatial_reference)
                batch_rasters = []
                tile_index += 1

        # Flush whatever is left after the last full aggregation group.
        if batch_rasters:
            self._write_tile(batch_rasters, output_dir, tile_index, spatial_reference)

        return

It is used to perform building change detection in DSM raster data. There are two options to input data: DSM from time point A and DSM from time point B, or differential DSM (tDOM in the code). If you input two DSMs the dDSM is calculated internally and then input to a segmentation model to do inference. Multiple output tiles can be stitched together before saving them to an output directory.

I can submit a PR for this if you like, but as mentioned I can't contribute much at the moment.

adamjstewart commented 1 year ago

Where are we at with this? Would love to have official ArcGIS support.