Open tritolol opened 2 years ago
Would you be interested in including this into torchgeo?
Yes!
Could you provide a gist of the implementation?
Sure. arcgisdatasets.py:
import sys
from typing import Any, Callable, Dict, Optional

import arcpy
import torch
import torchvision
from rasterio.crs import CRS
from torchgeo.datasets.geo import GeoDataset
from torchgeo.datasets.utils import BoundingBox


class ArcGisRasterDataset(GeoDataset):
    layer_name = ""

    def __init__(
        self,
        layer_name,
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        super().__init__(transforms)
        self.layer_name = layer_name
        # getOutput() returns strings; replace(",", ".") copes with locales that use a decimal comma
        top = float(arcpy.GetRasterProperties_management(self.layer_name, "TOP").getOutput(0).replace(",", "."))
        bottom = float(arcpy.GetRasterProperties_management(self.layer_name, "BOTTOM").getOutput(0).replace(",", "."))
        left = float(arcpy.GetRasterProperties_management(self.layer_name, "LEFT").getOutput(0).replace(",", "."))
        right = float(arcpy.GetRasterProperties_management(self.layer_name, "RIGHT").getOutput(0).replace(",", "."))
        # (minx, maxx, miny, maxy, mint, maxt)
        coords = (left, right, bottom, top, 0, sys.maxsize)
        self.extent = coords
        self.index.insert(0, coords)
        self.source_res = float(arcpy.GetRasterProperties_management(self.layer_name, "CELLSIZEX").getOutput(0).replace(",", "."))
        if res is None:
            self.res = self.source_res
        else:
            self.res = res
        spatial_ref = arcpy.Describe(self.layer_name).spatialReference.factoryCode
        self._crs = CRS.from_epsg(spatial_ref)

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        lower_left = arcpy.Point(query.minx, query.miny)
        # (ncols, nrows) at the source resolution, the order RasterToNumPyArray expects
        source_query_cell_size = (round((query.maxx - query.minx)/self.source_res), round((query.maxy - query.miny)/self.source_res))
        data = arcpy.RasterToNumPyArray(self.layer_name, lower_left, *source_query_cell_size)
        data = torch.tensor(data)
        if self.res != self.source_res:
            # Resize expects (height, width), i.e. (rows, cols)
            target_query_cell_size = (round((query.maxy - query.miny)/self.res), round((query.maxx - query.minx)/self.res))
            resize = torchvision.transforms.Resize(target_query_cell_size)
            data = resize(data)
        sample = {"image": data, "crs": self.crs, "bbox": query}
        return sample
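As a small illustration of the size arithmetic in __getitem__, here is a sketch in plain Python with no arcpy or torch dependency. The helper name query_shape and the example numbers are hypothetical and not part of the code above. It returns the shape as (rows, cols), which is the (height, width) order torchvision's Resize expects, whereas arcpy.RasterToNumPyArray takes the number of columns first.

```python
from typing import Tuple


def query_shape(minx: float, maxx: float, miny: float, maxy: float, res: float) -> Tuple[int, int]:
    """Return (rows, cols) of pixels covering the box at `res` map units per pixel.

    Hypothetical helper for illustration only.
    """
    cols = round((maxx - minx) / res)
    rows = round((maxy - miny) / res)
    return rows, cols


# A 204.8 x 102.4 unit query read at 0.1 units/pixel, then resampled to 0.2 units/pixel.
source_shape = query_shape(0.0, 204.8, 0.0, 102.4, res=0.1)
target_shape = query_shape(0.0, 204.8, 0.0, 102.4, res=0.2)
print(source_shape, target_shape)
```

For square queries such as the 1024 x 1024 patches used in this thread the order makes no difference, so a swapped order would only show up with rectangular queries.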
testtool2.pyt (ArcGIS Python Toolbox script):
#from testtool2 import *
# -*- coding: utf-8 -*-
from torchgeo.samplers import GridGeoSampler
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from arcgisdatasets import ArcGisRasterDataset
import torch
import itertools
import arcpy


class Toolbox(object):
    def __init__(self):
        """Define the toolbox (the name of the toolbox is the name of the
        .pyt file)."""
        self.label = "Toolbox"
        self.alias = "toolbox"
        # List of tool classes associated with this toolbox
        self.tools = [Tool]


class Tool(object):
    def __init__(self):
        """Define the tool (tool name is the name of the class)."""
        self.label = "Tool"
        self.description = ""
        self.canRunInBackground = False

    def getParameterInfo(self):
        """Define parameter definitions"""
        # First parameter
        param0 = arcpy.Parameter(
            displayName="Input Raster A",
            name="in_raster_a",
            datatype="GPRasterLayer",
            parameterType="Required",
            direction="Input")
        # Second parameter
        param1 = arcpy.Parameter(
            displayName="Input Raster B",
            name="in_raster_b",
            datatype="GPRasterLayer",
            parameterType="Required",
            direction="Input")
        # Third parameter
        param2 = arcpy.Parameter(
            displayName="Output Features",
            name="out_features",
            datatype="GPFeatureLayer",
            parameterType="Required",
            direction="Output")
        params = [param0, param1, param2]
        return params

    def isLicensed(self):
        """Set whether tool is licensed to execute."""
        return True

    def updateParameters(self, parameters):
        """Modify the values and properties of parameters before internal
        validation is performed. This method is called whenever a parameter
        has been changed."""
        return

    def updateMessages(self, parameters):
        """Modify the messages created by internal validation for each tool
        parameter. This method is called after internal validation."""
        return

    def execute(self, parameters, messages):
        """The source code of the tool."""
        batch_size = 8
        raster_a = ArcGisRasterDataset(parameters[0].valueAsText, res=0.2)
        raster_b = ArcGisRasterDataset(parameters[1].valueAsText, res=0.2)
        dataset = raster_a & raster_b
        gs = GridGeoSampler(dataset, 1024, 1024)
        # num_workers > 0 not working
        dataloader = DataLoader(dataset, sampler=gs, collate_fn=stack_samples, batch_size=batch_size, num_workers=0)
        arcpy.env.compression = "LZW"
        for batch_idx, batch in enumerate(dataloader):
            batch_np = batch["image"].numpy()
            # ToDo: Inference
            # Convert input batch back to arcgis raster
            # (iterate over the actual batch length: the last batch may hold fewer than batch_size samples)
            batch_rasters = [arcpy.NumPyArrayToRaster(batch_np[i], lower_left_corner=arcpy.Point(batch["bbox"][i].minx, batch["bbox"][i].miny), x_cell_size=dataset.res) for i in range(len(batch_np))]
            # Assemble batch rasters to a single raster and write to results folder, one file per batch
            arcpy.management.MosaicToNewRaster(batch_rasters, parameters[2].valueAsText, "%d.tif" % batch_idx, coordinate_system_for_the_raster=arcpy.Describe(parameters[0].valueAsText).spatialReference, number_of_bands=batch_np.shape[1])
        return
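To make the tiling step more concrete, here is a rough sketch of what grid sampling over a raster extent amounts to, written in plain Python. It is a simplified illustration, not TorchGeo's GridGeoSampler: the function name grid_boxes is hypothetical, size and stride are given in map units rather than pixels, and tiles that would extend past the extent are simply dropped.

```python
from typing import Iterator, Tuple

Box = Tuple[float, float, float, float]  # (minx, maxx, miny, maxy)


def grid_boxes(extent: Box, size: float, stride: float) -> Iterator[Box]:
    """Yield square boxes of side `size`, spaced `stride` apart, inside `extent`.

    Hypothetical helper: partial tiles at the right and top edges are skipped.
    """
    minx, maxx, miny, maxy = extent
    # Count full tiles per axis; a non-positive count yields nothing.
    cols = int((maxx - minx - size) // stride) + 1
    rows = int((maxy - miny - size) // stride) + 1
    for r in range(rows):
        for c in range(cols):
            x = minx + c * stride
            y = miny + r * stride
            yield (x, x + size, y, y + size)


# A 300 x 200 unit extent cut into non-overlapping 100 x 100 tiles.
boxes = list(grid_boxes((0.0, 300.0, 0.0, 200.0), size=100.0, stride=100.0))
print(len(boxes), boxes[0], boxes[-1])
```

Each box plays the role of the query passed to ArcGisRasterDataset.__getitem__, and its lower-left corner is what the tool later hands to arcpy.NumPyArrayToRaster.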
I would be interested in collaborating on any interface between ArcGIS Pro and torchgeo. Has any more work been done on this or related subjects?
I think the next step is figuring out how all of this fits into TorchGeo. For arcgisdatasets.py it's obvious where to put it. testtool2.pyt is a little less obvious to me. Is this something that should be installed when you run pip install torchgeo? Or is this just an example of usage of the tool? If the latter, we'll definitely want to write a Jupyter Notebook tutorial to add to the docs to show how to use it.
I'm unfortunately ignorant of the inner workings of ArcGIS (I mostly use QGIS but have never written tools for either) so I'm relying on y'all to know what's needed and how it should work. I think I actually have a license for ArcGIS through the university, so I'll be the perfect guinea pig to test the effectiveness of the tutorial we add.
Bonus points if anyone out there is familiar with QGIS and wants to write something similar for that.
testtool2.pyt is designed to be interpreted by ArcGIS Pro. This file is used to embed a nice little GUI for the custom tool in ArcGIS Pro. This code does not necessarily need to be packaged with torchgeo, but it should be made available to torchgeo users who also use ArcGIS Pro. Maybe an optional pip install torchgeo[arcgis] install, but the user will still need to import testtool2.pyt into their ArcGIS project or configure their environment to import it by default.
I see two opportunities for interoperability:
Personally, I would be excited about both, but mainly # 2. The mechanics and logistics for both would be found in the links and wikis in the following repos:
For guidance, we could look to the output of model training in ArcGIS Pro. Generally speaking the process yields the following:
The raster function uses the updatePixels method on the pixelBlocks object.
ArcGIS Pro and torchgeo both have strengths and weaknesses, and I think that having the best of both worlds would be great for research and productivity. ArcGIS Pro is great for interactively creating training data, and it is good for training models even with multiple GPUs, but inferencing at scale can be a struggle. I think it would be great, for example, to take a model trained in ArcGIS Pro and then deploy it for inferencing using torchgeo. This would also make the models more portable and useful across platforms, as ArcGIS Pro cannot be used in Linux environments. I have often thought about working on bridging the two, but I think it may be challenging. Any thoughts, guidance, or discussion would be appreciated.
ArcGIS Pro ... is good for training models even with multiple GPUs
For the record, TorchGeo can also do this using PyTorch Lightning. Lightning supports distributed training on multiple nodes, or on multiple GPUs on the same node. And training metrics can be automatically logged to TensorBoard. But I agree that the interactivity of ArcGIS is nice.
This all sounds great, and thanks for the documentation links! I think there are definitely a lot of possibilities for integration here. Let's start small and get the basic, most commonly used things working, and expand from there. The dataset defined above can be added to torchgeo/datasets. For the .pyt file, where does ArcGIS search for these files? Is there some environment variable that needs to be set? If it's small enough we can also just store it in the documentation. For most other things, creating a new torchgeo/utils directory may make sense.
@tritolol would you like to submit a PR to add your preliminary code to TorchGeo? We can iterate from there.
It's been a while since I worked on this. However, I eventually managed to get a torchgeo-based inference loop running. The development of the toolbox was quite painful since I didn't find a convenient way to debug. I think I had to use Visual Studio (without "Code") to get breakpoints to work and had to reload the toolbox inside ArcGIS Pro every time I made a change. This seems to be an issue for others as well: https://community.esri.com/t5/arcgis-pro-questions/debugging-a-python-toolbox-using-visual-studio/td-p/1079955 As far as I know, the toolbox script must be in the folder of the ArcGIS project you want to use it in, but there might be other options.
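One possible way to ease the debugging pain, sketched here only as an idea: keep the decision logic of the tool in plain functions that read nothing but .value and .valueAsText from the parameters, so it can be exercised in an ordinary Python session without reloading the toolbox. FakeParameter and select_input_layers are hypothetical names, and the layer names are made up; the parameter order follows the final toolbox script (data type first, then raster A, then raster B).

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FakeParameter:
    """Hypothetical stand-in for arcpy.Parameter, exposing only what the logic reads."""

    value: Any = None

    @property
    def valueAsText(self) -> str:
        return "" if self.value is None else str(self.value)


def select_input_layers(parameters: List[Any]) -> List[str]:
    """Decide which raster layers the tool should read for the chosen input type."""
    mode = parameters[0].valueAsText
    if mode == "tDOM":
        return [parameters[1].valueAsText]
    if mode == "DOM A, DOM B":
        return [parameters[1].valueAsText, parameters[2].valueAsText]
    raise ValueError(f"Unknown input data type: {mode!r}")


# Exercise the logic outside ArcGIS Pro with made-up layer names.
params = [FakeParameter("DOM A, DOM B"), FakeParameter("dsm_2019"), FakeParameter("dsm_2021")]
layers = select_input_layers(params)
print(layers)
```

The arcpy calls themselves still need ArcGIS Pro, so this only shrinks the part that has to be debugged inside the application.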
Before continuing integration we should find a more convenient way to develop toolboxes. Also, I don't really have the time at the moment to pursue this further. Anyway, this is the final state of testtool2.pyt:
#from testtool2 import *
# -*- coding: utf-8 -*-
import itertools
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast

import arcpy
import segmentation_models_pytorch as smp
import torch
import torchvision
from rasterio.crs import CRS
from torch import Tensor
from torch.nn.modules import Module
from torch.utils.data import DataLoader
from torchgeo.datasets.geo import GeoDataset
from torchgeo.datasets.utils import BoundingBox, stack_samples
from torchgeo.samplers import GridGeoSampler
class NoDataReplacer(Module):
    def __init__(self, nodata_value: float, replace_with: float) -> None:
        """Initialize a new transform instance.

        Args:
            nodata_value: pixel value that marks missing data
            replace_with: value written in place of nodata_value
        """
        super().__init__()
        self.nodat = nodata_value
        self.replace = replace_with

    def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # Overwrite nodata pixels in place
        sample["image"][sample["image"] == self.nodat] = self.replace
        return sample
class ArcGisRasterDataset(GeoDataset):
    layer_name = ""

    def __init__(
        self,
        layer_name,
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        super().__init__(transforms)
        self.layer_name = layer_name
        # getOutput() returns strings; replace(",", ".") copes with locales that use a decimal comma
        top = float(arcpy.GetRasterProperties_management(self.layer_name, "TOP").getOutput(0).replace(",", "."))
        bottom = float(arcpy.GetRasterProperties_management(self.layer_name, "BOTTOM").getOutput(0).replace(",", "."))
        left = float(arcpy.GetRasterProperties_management(self.layer_name, "LEFT").getOutput(0).replace(",", "."))
        right = float(arcpy.GetRasterProperties_management(self.layer_name, "RIGHT").getOutput(0).replace(",", "."))
        # (minx, maxx, miny, maxy, mint, maxt)
        coords = (left, right, bottom, top, 0, sys.maxsize)
        self.extent = coords
        self.index.insert(0, coords)
        self.source_res = float(arcpy.GetRasterProperties_management(self.layer_name, "CELLSIZEX").getOutput(0).replace(",", "."))
        if res is None:
            self.res = self.source_res
        else:
            self.res = res
        self.bands = int(arcpy.GetRasterProperties_management(self.layer_name, "BANDCOUNT").getOutput(0))
        #arcpy.AddMessage(self.bands)
        description = arcpy.Describe(self.layer_name)
        #self.nodata = description.noDataValue
        self._crs = CRS.from_epsg(description.spatialReference.factoryCode)
        self.transforms = transforms

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        lower_left = arcpy.Point(query.minx, query.miny)
        # (ncols, nrows) at the source resolution, the order RasterToNumPyArray expects
        source_query_cell_size = (round((query.maxx - query.minx)/self.source_res), round((query.maxy - query.miny)/self.source_res))
        data = arcpy.RasterToNumPyArray(self.layer_name, lower_left, *source_query_cell_size)
        data = torch.tensor(data)
        if self.bands == 1:
            # Single-band rasters come back as 2D arrays; add a channel dimension
            data = data.unsqueeze(0)
        #arcpy.AddMessage(data.dtype)
        sample = {"image": data, "crs": self.crs, "bbox": query}
        if self.transforms is not None:
            sample = self.transforms(sample)
        if self.res != self.source_res:
            # Resize expects (height, width), i.e. (rows, cols)
            target_query_cell_size = (round((query.maxy - query.miny)/self.res), round((query.maxx - query.minx)/self.res))
            resize = torchvision.transforms.Resize(target_query_cell_size)
            data = sample["image"]
            data = resize(data)
            sample = {"image": data, "crs": self.crs, "bbox": query}
        return sample
class Toolbox(object):
    def __init__(self):
        """Define the toolbox (the name of the toolbox is the name of the
        .pyt file)."""
        self.label = "Toolbox"
        self.alias = "toolbox"
        # List of tool classes associated with this toolbox
        self.tools = [Tool]


class Tool(object):
    def __init__(self):
        """Define the tool (tool name is the name of the class)."""
        self.label = "Tool"
        self.description = ""
        self.canRunInBackground = False
        #self.params = arcpy.GetParameterInfo()

    def getParameterInfo(self):
        """Define parameter definitions"""
        # First parameter
        param0 = arcpy.Parameter(
            displayName="Input Raster A",
            name="in_raster_a",
            datatype="GPRasterLayer",
            parameterType="Required",
            direction="Input")
        # Second parameter
        param1 = arcpy.Parameter(
            displayName="Input Raster B",
            name="in_raster_b",
            datatype="GPRasterLayer",
            parameterType="Optional",
            direction="Input")
        # Third parameter
        param2 = arcpy.Parameter(
            displayName="Output Tiles",
            name="out_tiles",
            datatype="DEFolder",
            parameterType="Required",
            direction="Output")
        param3 = arcpy.Parameter(
            displayName="Input data type",
            name="in_value",
            datatype="GPString",
            parameterType="Required",
            direction="Input")
        param4 = arcpy.Parameter(
            displayName="Model",
            name="in_model",
            datatype="DEFile",
            parameterType="Required",
            direction="Input")
        param5 = arcpy.Parameter(
            displayName="Batch Size",
            name="in_bs",
            datatype="GPLong",
            parameterType="Required",
            direction="Input")
        param6 = arcpy.Parameter(
            displayName="Tile Aggregation Size",
            name="in_as",
            datatype="GPLong",
            parameterType="Required",
            direction="Input")
        # Restrict the input data type to the two supported modes
        param3.filter.type = "ValueList"
        param3.filter.list = ["DOM A, DOM B", "tDOM"]
        # Only accept .pt model files
        param4.filter.list = ['pt']
        # Batch size: 1 to 16
        param5.filter.type = "ValueList"
        param5.filter.list = list(range(1, 17))
        # Tile aggregation size: 1 to 100
        param6.filter.type = "ValueList"
        param6.filter.list = list(range(1, 101))
        # Note the order: the data type comes first, then raster A, raster B, output folder, ...
        params = [param3, param0, param1, param2, param4, param5, param6]
        return params
    #def initializeParameters(self, parameters):
        #parameters[1].enabled = False

    def isLicensed(self):
        """Set whether tool is licensed to execute."""
        return True

    def updateParameters(self, parameters):
        """Modify the values and properties of parameters before internal
        validation is performed. This method is called whenever a parameter
        has been changed."""
        if parameters[0].value == "tDOM":
            parameters[1].enabled = True
            parameters[2].enabled = False
        elif parameters[0].value == "DOM A, DOM B":
            parameters[1].enabled = True
            parameters[2].enabled = True
        else:
            parameters[1].enabled = False
            parameters[2].enabled = False
        return

    def updateMessages(self, parameters):
        """Modify the messages created by internal validation for each tool
        parameter. This method is called after internal validation."""
        return
    def execute(self, parameters, messages):
        """The source code of the tool."""
        batch_size = parameters[5].value
        if parameters[0].valueAsText == "tDOM":
            description = arcpy.Describe(parameters[1].valueAsText)
            nodat_repl = NoDataReplacer(description.noDataValue, 0.0)
            dataset = ArcGisRasterDataset(parameters[1].valueAsText, res=0.2, transforms=nodat_repl)
            if round(dataset.source_res, 3) > 0.2:
                raise ValueError('A source dataset\'s resolution is too low. Minimum ground sampling distance: 0.2m')
        elif parameters[0].value == "DOM A, DOM B":
            description = arcpy.Describe(parameters[1].valueAsText)
            nodat_repl = NoDataReplacer(description.noDataValue, 0.0)
            raster_a = ArcGisRasterDataset(parameters[1].valueAsText, res=0.2, transforms=nodat_repl)
            if raster_a.source_res > 0.2:
                raise ValueError('A source dataset\'s resolution is too low. Minimum ground sampling distance: 0.2m')
            description = arcpy.Describe(parameters[2].valueAsText)
            nodat_repl = NoDataReplacer(description.noDataValue, 0.0)
            raster_b = ArcGisRasterDataset(parameters[2].valueAsText, res=0.2, transforms=nodat_repl)
            if raster_b.source_res > 0.2:
                raise ValueError('A source dataset\'s resolution is too low. Minimum ground sampling distance: 0.2m')
            dataset = raster_a & raster_b
        gs = GridGeoSampler(dataset, 1024, 1024)
        dataloader = DataLoader(dataset, sampler=gs, collate_fn=stack_samples, batch_size=batch_size, num_workers=0)
        arcpy.env.compression = "LZW"
        model_name = parameters[4].valueAsText
        model = smp.Unet(encoder_name="resnet101", encoder_weights="imagenet", in_channels=1).cuda()
        checkpoint = torch.load(model_name)
        # Drop the "module." part from the checkpoint keys so they match the bare model
        fixed_state_dict = {}
        for k, v in checkpoint['model_state_dict'].items():
            k = k.replace("module.", "")
            fixed_state_dict[k] = v
        model.load_state_dict(fixed_state_dict)
        model.eval()
        os.mkdir(parameters[3].valueAsText)
        arcpy.AddMessage("Processing %d batches..." % len(dataloader))
        arcpy.SetProgressor("step", "Processing %d batches..." % len(dataloader), 0, len(dataloader), 1)
        batch_rasters = []
        j = 0  # index of the next output tile
        for i, batch in enumerate(dataloader):
            arcpy.SetProgressorPosition(i + 1)
            arcpy.SetProgressorLabel("Processing batch: {}".format(i + 1))
            batch_data = batch["image"]
            if parameters[0].value == "DOM A, DOM B":
                batch_data = (batch_data[:, 1, :, :] - batch_data[:, 0, :, :]).unsqueeze(1)  # calculate tDOM
            with torch.no_grad():
                y = model(batch_data.cuda())
                y = torch.nn.functional.relu(y).cpu().numpy()
            #arcpy.AddMessage(y.min())
            #arcpy.AddMessage(y.max())
            # Convert model output back to arcgis rasters
            # (iterate over the actual batch length: the last batch may hold fewer than batch_size samples)
            batch_rasters += [arcpy.NumPyArrayToRaster(y[n], lower_left_corner=arcpy.Point(batch["bbox"][n].minx, batch["bbox"][n].miny), x_cell_size=dataset.res, value_to_nodata=0) for n in range(len(y))]
            # Assemble batch rasters to a single raster and write to results folder
            if (i + 1) % parameters[6].value == 0:
                arcpy.SetProgressorLabel("Writing tile to disk...")
                arcpy.management.MosaicToNewRaster(batch_rasters, parameters[3].valueAsText, "%d.tif" % j, coordinate_system_for_the_raster=arcpy.Describe(parameters[1].valueAsText).spatialReference, number_of_bands=1, pixel_type="32_BIT_FLOAT")
                batch_rasters = []
                j += 1
        # Write whatever is left over after the last full aggregation step
        if batch_rasters:
            arcpy.SetProgressorLabel("Writing tile to disk...")
            arcpy.management.MosaicToNewRaster(batch_rasters, parameters[3].valueAsText, "%d.tif" % j, coordinate_system_for_the_raster=arcpy.Describe(parameters[1].valueAsText).spatialReference, number_of_bands=1, pixel_type="32_BIT_FLOAT")
        return
It is used to perform building change detection in DSM raster data. There are two ways to provide input data: a DSM from time point A plus a DSM from time point B, or a single differential DSM (tDOM in the code). If you input two DSMs, the differential DSM is calculated internally (B minus A) and then passed to a segmentation model for inference. Multiple output tiles can be stitched together before saving them to an output directory.
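The stitching step can be pictured as follows. This is a sketch in plain Python of the aggregation pattern only (collect the results of several batches, flush every N batches, then flush the remainder); the function name aggregate is hypothetical, and integers stand in for the arcpy rasters that the tool passes to MosaicToNewRaster.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def aggregate(batches: Iterable[List[T]], every: int) -> Iterator[List[T]]:
    """Group the items of `every` consecutive batches; yield a final partial group if any.

    Hypothetical helper mirroring the "Tile Aggregation Size" logic.
    """
    pending: List[T] = []
    for count, batch in enumerate(batches, start=1):
        pending.extend(batch)
        if count % every == 0:
            yield pending
            pending = []
    if pending:
        # Leftover items from an incomplete last group
        yield pending


# Three batches (the last one smaller), aggregated two batches per output tile.
groups = list(aggregate([[1, 2], [3, 4], [5]], every=2))
print(groups)
```

Each yielded group would correspond to one output .tif in the results folder.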
I can submit a PR for this if you like, but as mentioned I can't contribute much at the moment.
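A side note on the checkpoint loading in the script: the "module." prefix is what torch.nn.DataParallel and DistributedDataParallel add to parameter names, so presumably the model was trained wrapped in one of them. The script removes the substring wherever it occurs in a key. A slightly stricter variant, sketched below with a hypothetical helper name and made-up keys, strips it only when it is a leading prefix. Plain integers stand in for tensors so the example runs without torch.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Dict[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of `state_dict` with a leading `prefix` removed from each key.

    Hypothetical helper; keys without the prefix are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Made-up keys: only the leading "module." is removed, an inner one is left alone.
checkpoint = {"module.encoder.conv1.weight": 1, "module.decoder.module.bias": 2}
cleaned = strip_prefix(checkpoint)
print(cleaned)
```

For the segmentation model used here both approaches most likely give the same keys, so this is a robustness detail rather than a fix.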
Where are we at with this? Would love to have official ArcGIS support.
Hi, I'm currently working on an interface between ESRI ArcGIS Pro and torchgeo. The basic functionality is done: I'm able to sample raster layers directly from a running ArcGIS application using torchgeo samplers. Technically, this is just a new dataset class inheriting from GeoDataset, currently only about 50 lines of code. At the moment I'm planning to use this only for inference purposes, since training within an ArcGIS Pro instance isn't feasible imo.
The motivation behind writing this interface is to give city administration staff who work with this software a simple way of using arbitrary pretrained PyTorch models. Currently, I'm using a Python toolbox script to let the user define input and output layers; it contains just a few lines of code, which is a nice thing.
I'm aware of the Deep Learning toolbox in ESRI's Image Analyst package. However, I didn't find an easy way to incorporate custom trained models into this workflow.
Would you be interested in including this into torchgeo?