Summary
Add a plot method to IntersectionDataset that calls the plot method of all GeoDatasets and merges their output into a single figure.
Rationale
TorchGeo currently relies on the following fundamental design principle:
Each dataset has a plot method that returns a single Figure
Each data module has a plot method that calls dataset.plot and returns a single Figure
Each trainer calls this plot method during validation and logs a single Figure to TensorBoard
This works great for most NonGeoDatasets and GeoDatasets, but falls apart as soon as you use an IntersectionDataset. Merging two Figures together is non-trivial, and so far we've simply ignored this use case by plotting the mask/prediction but not the image. This prevents proper visualization of training and makes IntersectionDatasets different from other GeoDatasets in terms of features.
Implementation
This idea can be implemented in multiple stages (possibly by multiple people).
First, we'll need to add an abstract plot method to GeoDataset. A few point datasets do not yet have plot methods (EDDMapS, GBIF, INaturalist), so we'll have to add one to each of them. We might as well add an abstract plot method to NonGeoDataset too, to enforce consistency. This is likely a good idea regardless; here it is a prerequisite for what follows.
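A minimal sketch of what an abstract plot method enforces, using only the standard library. The class names `PlottableDataset` and `PointDatasetWithoutPlot` are hypothetical stand-ins, not TorchGeo's actual classes, and the signature is an assumption:

```python
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:  # only needed by type checkers, never imported at runtime
    from matplotlib.figure import Figure


class PlottableDataset(ABC):
    """Hypothetical stand-in for GeoDataset (not the real TorchGeo class)."""

    @abstractmethod
    def plot(self, sample: dict[str, Any]) -> "Figure":
        """Return a single Figure visualizing ``sample``."""


class PointDatasetWithoutPlot(PlottableDataset):
    """Mimics a point dataset that has not implemented plot yet."""


try:
    PointDatasetWithoutPlot()
except TypeError as err:
    # Python refuses to instantiate a subclass that lacks an abstract method.
    print(f"Cannot instantiate: {err}")
```

With this in place, any dataset missing a plot method fails at construction time rather than during validation.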
Second, we'll need to add a new optional parameter to all GeoDataset.plot methods: a pointer to the subplot to draw on. Normally, GeoDataset.plot will create its own Figure and return it, but IntersectionDataset.plot will instead create a Figure and call each GeoDataset.plot method with the same sample and this additional pointer telling it which subplot to populate. Alternatively, we could pass it an existing Figure and have it select the next empty subplot. The number of subplots to create in IntersectionDataset.plot could be based either on the number of datasets (which would exclude plotting prediction maps) or on the number of image/mask/prediction elements in the sample dict (which relies on #985 if we want any guarantees). See Additional information for caveats.
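One way the optional-pointer idea could look, sketched with matplotlib. Everything here is an assumption for illustration: the parameter name `ax`, the classes `ImageSketch`, `MaskSketch`, and `IntersectionSketch`, and the choice of one subplot per dataset. None of it is TorchGeo's current API.

```python
from typing import Any, Optional

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


class ImageSketch:
    """Hypothetical stand-in for an image GeoDataset."""

    def plot(self, sample: dict[str, Any], ax: Optional[Axes] = None):
        if ax is None:  # standalone use: create our own Figure
            _, ax = plt.subplots()
        ax.imshow(sample["image"])
        ax.set_title("Image")
        ax.axis("off")
        return ax.figure


class MaskSketch:
    """Hypothetical stand-in for a mask GeoDataset."""

    def plot(self, sample: dict[str, Any], ax: Optional[Axes] = None):
        if ax is None:
            _, ax = plt.subplots()
        ax.imshow(sample["mask"], interpolation="none")
        ax.set_title("Mask")
        ax.axis("off")
        return ax.figure


class IntersectionSketch:
    """Hypothetical merged plot: one subplot per underlying dataset."""

    def __init__(self, *datasets: Any) -> None:
        self.datasets = datasets

    def plot(self, sample: dict[str, Any]):
        # squeeze=False keeps axs two-dimensional even for a single dataset.
        fig, axs = plt.subplots(ncols=len(self.datasets), squeeze=False)
        for dataset, ax in zip(self.datasets, axs[0]):
            dataset.plot(sample, ax=ax)  # each dataset fills its own subplot
        return fig


sample = {"image": np.random.rand(8, 8, 3), "mask": np.random.randint(0, 3, (8, 8))}
merged = IntersectionSketch(ImageSketch(), MaskSketch()).plot(sample)
standalone = ImageSketch().plot(sample)
```

This variant sizes the Figure by the number of datasets, so it shares the limitation noted above: a prediction in the sample dict would not get a subplot.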
Alternatives
We could instead change GeoDataset.plot and GeoDataModule.plot to return one or more Figures and our trainers to log a list of Figures. This may actually be significantly simpler to implement. The only downside is that not all datasets will be consistent, and it would be nice to be able to display a single Figure with all information.
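A rough sketch of this alternative, assuming a plot function that returns a list of Figures and a trainer-side helper that logs each one under its own tag. The names `plot_all` and `log_figures`, the tag scheme, and the `log_fn` callback (a stand-in for whatever logger the trainer uses) are all hypothetical:

```python
from typing import Any, Callable

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure


def plot_all(sample: dict[str, Any]) -> list[Figure]:
    """Hypothetical plot that returns one Figure per plottable key."""
    figures = []
    for key in ("image", "mask", "prediction"):
        if key in sample:
            fig, ax = plt.subplots()
            ax.imshow(sample[key])
            ax.set_title(key)
            ax.axis("off")
            figures.append(fig)
    return figures


def log_figures(
    figures: list[Figure], log_fn: Callable[[str, Figure], None], tag: str
) -> None:
    """Log every Figure under an indexed tag, then release it."""
    for i, fig in enumerate(figures):
        log_fn(f"{tag}/{i}", fig)
        plt.close(fig)  # avoid accumulating open figures across validation steps


logged: list[str] = []
sample = {"image": np.random.rand(8, 8, 3), "mask": np.zeros((8, 8))}
log_figures(plot_all(sample), lambda tag, fig: logged.append(tag), tag="val")
```

The trade-off described above is visible here: the image and mask end up as separate logged entries rather than side by side in one Figure.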
Additional information
Some remaining TODOs to figure out:
How will this work for UnionDataset?
How will this work for IntersectionDataset with multiple images (e.g., Landsat & Sentinel)?
This is related to #1263 and could likely be done in parallel.