microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

GridGeoSampler resamples same image repeatedly with separate_files and multiple dates #2221

Closed: sfalkena closed this issue 3 weeks ago

sfalkena commented 1 month ago

Description

Hi,

Whenever I create a dataset that uses separate files for each band and has multiple dates available, the sampler produces a sample for every band and every date (i.e., for every file that intersects the dataset index). I can understand wanting one sample per date, but with the current approach the dataset receives a BoundingBox from the sampler for every band. The same spatial area is therefore sampled over and over, which wastes compute during inference.

I have created a reproducible example using the Sentinel2 test files as input. Right now I'm printing the sum of each tensor to show that the same 2 tensors are each retrieved 4 times. I agree that sampler.hits should have length 8, but in my opinion len(sampler) should be 2.

Is this behavior expected? I have been thinking about ways to work around this issue. For a dataset like Sentinel2 we could probably leverage the tile ID, but I would prefer a more generalizable option.

Steps to reproduce

from torchgeo.datasets import Sentinel2
from torchgeo.samplers import GridGeoSampler
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples, unbind_samples

sentinel_urls = [
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20190414T110751_AOT_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20190414T110751_B02_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20190414T110751_B03_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20190414T110751_B04_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20190414T110751_B08_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20190414T110751_TCI_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20190414T110751_WVP_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_AOT_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_B02_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_B03_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_B04_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_B08_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_TCI_10m.jp2",
    "https://github.com/microsoft/torchgeo/raw/main/tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_WVP_10m.jp2",
]

dataset = Sentinel2(
    paths=sentinel_urls,
    cache=False,
    bands=["B02", "B03", "B04", "B08"],
)

sampler = GridGeoSampler(dataset, size=(128, 128), stride=(10, 10))

# Since we specify only 4 bands, we are expecting 8 hits, 2 timestamps per band.
print(len(sampler.hits))
# Here I would expect to have 2 samples (1 per timestamp)
print(len(sampler))

dataloader = DataLoader(dataset, sampler=sampler, batch_size=1, collate_fn=stack_samples, num_workers=8)

for batch in dataloader:
    for sample in unbind_samples(batch):
        print(sample['image'].sum())
        print(sample['image'].shape)

Version

0.6.0.dev0, e973c1e3ca126f9f76573f43d8bfba910fa78efa

adamjstewart commented 1 month ago

Impact

This bug only occurs when a list of files is provided (either URLs or paths). It does not occur when a directory is provided. To reproduce this issue more quickly, the list of URLs can be replaced with a list of files in the git checkout:

sentinel_urls = [url[47:] for url in sentinel_urls]
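Slicing with [47:] depends on the exact length of the GitHub prefix. A slightly more robust sketch strips the prefix explicitly and fails loudly if a URL does not have it. to_local_paths is a hypothetical helper, and the resulting relative paths assume the script is run from the root of the torchgeo checkout.

```python
# Hypothetical helper: convert raw GitHub URLs into paths relative to a local
# torchgeo checkout without hard-coding the prefix length (47 characters).
PREFIX = "https://github.com/microsoft/torchgeo/raw/main/"


def to_local_paths(urls: list[str], prefix: str = PREFIX) -> list[str]:
    """Strip the raw-content prefix; raise if a URL does not start with it."""
    paths = []
    for url in urls:
        if not url.startswith(prefix):
            raise ValueError(f"unexpected URL: {url}")
        paths.append(url.removeprefix(prefix))  # requires Python 3.9+
    return paths


url = PREFIX + (
    "tests/data/sentinel2/S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE/"
    "GRANULE/L2A_T26EMU_A035569_20220414T110747/IMG_DATA/R10m/T26EMU_20220414T110751_B02_10m.jp2"
)
print(to_local_paths([url])[0])
```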

Analysis

The underlying issue is that TorchGeo's R-tree index is designed to hold only a single band when separate_files == True. Normally, this is enforced using a filename_glob that contains the band name, so only one band gets picked up. However, filename_glob is only applied when a directory is given; it is skipped when individual files are passed.
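A simplified sketch of the directory-mode filtering described above. It assumes Sentinel2's filename_glob has the form "T*_*_{}*.*" and is formatted with the first requested band (check the torchgeo source for the exact pattern); indexed_files is a hypothetical stand-in, not torchgeo's code. When an explicit file list is passed, this filter is skipped, so every band file ends up in the R-tree.

```python
import os
from fnmatch import fnmatch

# Assumption: a Sentinel-2 style glob, formatted with the first requested band.
FILENAME_GLOB = "T*_*_{}*.*"

files = [
    "R10m/T26EMU_20190414T110751_AOT_10m.jp2",
    "R10m/T26EMU_20190414T110751_B02_10m.jp2",
    "R10m/T26EMU_20190414T110751_B03_10m.jp2",
    "R10m/T26EMU_20220414T110751_B02_10m.jp2",
    "R10m/T26EMU_20220414T110751_B03_10m.jp2",
    "R10m/T26EMU_20220414T110751_TCI_10m.jp2",
]


def indexed_files(paths: list[str], bands: list[str]) -> list[str]:
    """Simplified directory-mode behaviour: keep only files matching the glob."""
    glob = FILENAME_GLOB.format(bands[0])
    return [p for p in paths if fnmatch(os.path.basename(p), glob)]


# Only the B02 file of each date survives: one index entry per date.
print(indexed_files(files, ["B02", "B03"]))
```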

Workaround

It's actually really easy to avoid this bug. Simply pass only one band file per image (date) instead of every band in the file list; the remaining bands are still located from that file's name. For example:

sentinel_urls = [url for url in sentinel_urls if 'B02' in url]

Solution

The easiest way to fix this (i.e., to prevent users from shooting themselves in the foot) would be to ensure that filename_glob matches all files passed by the user.
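One way the "make sure filename_glob matches every file passed" fix could look is a guard run on explicit file lists. This is a hypothetical sketch (check_paths_match_glob is an illustrative name, not necessarily what the PR implements), and the glob value in the usage lines is an assumed Sentinel2 pattern for bands[0] == "B02".

```python
import os
from fnmatch import fnmatch


def check_paths_match_glob(paths: list[str], filename_glob: str) -> None:
    """Hypothetical guard: reject explicit file lists the glob would filter out."""
    mismatched = [p for p in paths if not fnmatch(os.path.basename(p), filename_glob)]
    if mismatched:
        raise ValueError(
            f"{len(mismatched)} file(s) do not match filename_glob "
            f"{filename_glob!r}, e.g. {mismatched[0]!r}"
        )


glob = "T*_*_B02*.*"  # assumed glob for bands[0] == "B02"
check_paths_match_glob(["T26EMU_20220414T110751_B02_10m.jp2"], glob)  # passes
try:
    check_paths_match_glob(["T26EMU_20220414T110751_B03_10m.jp2"], glob)
except ValueError as err:
    print(err)
```

Raising instead of silently dropping non-matching files makes the mistake visible to the user; quietly filtering the list would be the alternative design.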

A much harder fix would be to completely change how we model separate_files == True such that all bands end up in the R-tree but are somehow fused into a single bounding box.
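A rough sketch of the index-side grouping behind the harder fix: collect band files by (tile, date), so that each scene could become a single R-tree entry with one bounding box. The filename regex is an assumption tailored to the Sentinel-2 test files, and group_band_files is hypothetical.

```python
import re
from collections import defaultdict

# Assumed pattern for Sentinel-2 band files such as T26EMU_20220414T110751_B02_10m.jp2
PATTERN = re.compile(
    r"^(?P<tile>T\d{2}[A-Z]{3})_(?P<date>\d{8}T\d{6})_(?P<band>B\d[\dA])_\d+m\.\w+$"
)


def group_band_files(filenames: list[str], bands: list[str]) -> dict:
    """Map (tile, date) -> {band: filename}, keeping only scenes with every band."""
    groups: dict = defaultdict(dict)
    for name in filenames:
        match = PATTERN.match(name)
        if match and match["band"] in bands:
            groups[(match["tile"], match["date"])][match["band"]] = name
    return {key: found for key, found in groups.items() if set(found) == set(bands)}


names = [
    f"T26EMU_{date}_{product}_10m.jp2"
    for date in ("20190414T110751", "20220414T110751")
    for product in ("AOT", "B02", "B03", "B04", "B08", "TCI", "WVP")
]
scenes = group_band_files(names, ["B02", "B03", "B04", "B08"])
print(len(names), len(scenes))
```

Each group's bounds could then be read from any one of its files, assuming all bands in a group share the same footprint.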

FYI @adriantre

sfalkena commented 1 month ago

Thanks for your fast response. I overlooked that this was the purpose of filename_glob. I still find it a bit counterintuitive that the class starts "looking" for the other bands, since I initially expected it to use only the files I provide to the dataset and nothing else. However, I understand the reason for this implementation, and I think the proposed fix makes sense.

The harder fix you proposed is the one that makes more intuitive sense to me: we are already adding all files to the index, and a simple intersection returns exactly the files we would want to merge. So far, however, I haven't found a neat way to do the bookkeeping in the sampler class, but I am happy to think it through further with you if you also prefer this approach.
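On the sampler side, the simplest bookkeeping is to deduplicate hits with identical bounds before generating grid cells. A minimal sketch, using a hypothetical Box tuple that mirrors the minx/maxx/miny/maxy/mint/maxt fields of torchgeo's BoundingBox; the coordinates and times below are made up for illustration.

```python
from typing import NamedTuple


class Box(NamedTuple):
    """Stand-in for torchgeo's BoundingBox (hypothetical, field order assumed)."""

    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float


def unique_boxes(hit_bounds: list[Box]) -> list[Box]:
    """Drop repeated boxes while keeping first-seen order."""
    return list(dict.fromkeys(hit_bounds))


# 4 band files x 2 dates -> 8 hits, but only 2 distinct space-time extents.
hits = [Box(0, 1280, 0, 1280, t, t + 1) for t in (2019.0, 2022.0) for _band in range(4)]
print(len(hits), len(unique_boxes(hits)))
```

This only works when duplicate hits have exactly equal bounds; files with slightly different extents or timestamps would need a tolerance or an explicit grouping key instead.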

sfalkena commented 1 month ago

Now that the PR is approved, feel free to close this issue. Alternatively, we could keep it open as a thread for future discussion of the sampler approach.

adamjstewart commented 1 month ago

I'll leave this open until the PR is merged. We can open a separate discussion (not issue) for sampler approaches if needed.

adriantre commented 1 month ago

Quoting the comment above: "This bug only occurs when a list of files is provided (either URLs or paths). It does not occur when a directory is provided."

This is only tangentially relevant, but I checked and found that https "directories" cannot be listed using fiona.listdir, whereas zip archives served over https can.

Using listdir_vsi_recursive from #1399 (which uses fiona):

listdir_vsi_recursive("https://github.com/microsoft/torchgeo/raw/main/tests/data/cms_mangrove_canopy/CMS_Global_Map_Mangrove_Canopy_1665")
> ['https://github.com/microsoft/torchgeo/raw/main/tests/data/cms_mangrove_canopy/CMS_Global_Map_Mangrove_Canopy_1665']
# Does not list any files, because fiona treats the URL as a file (FionaValueError ... is not a directory, but exist.)

listdir_vsi_recursive("zip+https://github.com/microsoft/torchgeo/raw/main/tests/data/cms_mangrove_canopy/CMS_Global_Map_Mangrove_Canopy_1665.zip!/CMS_Global_Map_Mangrove_Canopy_1665")
> ['zip+https://github.com/microsoft/torchgeo/raw/main/tests/data/cms_mangrove_canopy/CMS_Global_Map_Mangrove_Canopy_1665.zip!/CMS_Global_Map_Mangrove_Canopy_1665/data/Mangrove_hmax95_Angola.tif',
   'zip+https://github.com/microsoft/torchgeo/raw/main/tests/data/cms_mangrove_canopy/CMS_Global_Map_Mangrove_Canopy_1665.zip!/CMS_Global_Map_Mangrove_Canopy_1665/data/Mangrove_hba95_Angola.tif',
   'zip+https://github.com/microsoft/torchgeo/raw/main/tests/data/cms_mangrove_canopy/CMS_Global_Map_Mangrove_Canopy_1665.zip!/CMS_Global_Map_Mangrove_Canopy_1665/data/Mangrove_agb_Angola.tif']
# Lists the files inside the archive

Edit: https URLs cannot be listed like the other virtual file systems; see Toblerity/Fiona#1429.
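The behaviour above (a path that is not recognised as a directory comes back unchanged) can be illustrated with a generic recursive lister. This is a hypothetical sketch, not the implementation from #1399: the listdir and isdir callables are injected, so in practice something like fiona.listdir plus an assumed directory test could be plugged in, while the usage below runs against a fake in-memory tree.

```python
from typing import Callable


def list_recursive(
    root: str,
    listdir: Callable[[str], list[str]],
    isdir: Callable[[str], bool],
) -> list[str]:
    """Return every file below root; a non-directory root is returned as-is."""
    if not isdir(root):
        return [root]
    files: list[str] = []
    for name in listdir(root):
        files.extend(list_recursive(f"{root}/{name}", listdir, isdir))
    return files


# Fake file system: keys are directories, values are their entries.
tree = {"archive.zip!/root": ["data"], "archive.zip!/root/data": ["a.tif", "b.tif"]}
print(list_recursive("archive.zip!/root", tree.__getitem__, tree.__contains__))
print(list_recursive("https://example.com/dir", tree.__getitem__, tree.__contains__))
```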