microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Add support for Lightning Streaming Dataset #1886

Open tchaton opened 9 months ago

tchaton commented 9 months ago

Summary

Dear people of TorchGeo,

I am hearing from more and more folks in the community who are training in the cloud and struggling with their data streaming solutions.

I have been working on this new framework: https://github.com/Lightning-AI/pytorch-lightning/tree/master/src/lightning/data.

I wondered if you would be interested in giving it a spin.

Rationale

Make it easier for people to consume datasets in the cloud

Implementation

No response

Alternatives

No response

Additional information

No response

adamjstewart commented 9 months ago

@robmarkcole has been exploring this and may be able to comment on how feasible it is to integrate this into TorchGeo.

robmarkcole commented 9 months ago

It is pretty straightforward in two steps: (1) process the dataset into the required binary format and host it in the cloud (I am using AWS; it is TBC who would host this for these public datasets), and (2) give each dataset a complementary streaming version.

adamjstewart commented 9 months ago

each dataset then has a complementary streaming version

This is what I'm trying to avoid. I don't want 2+ copies of all 75+ datasets in TorchGeo. I'm fine with adding a new subclass of GeoDataset or RasterDataset though.

robmarkcole commented 9 months ago

I'm not sure there is a way around it. In my implementation, the data module accepts a 'streaming' arg and returns either the regular or the streaming dataset.
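A minimal sketch of the 'streaming' switch described above, assuming the data module only needs to choose which dataset class to build. `StreamingSwitchDataModule`, `build_dataset`, and the kwargs fields are hypothetical names, not TorchGeo or Lightning API; in practice this logic would likely live in a LightningDataModule's `setup()`.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class StreamingSwitchDataModule:
    """Hypothetical data module that returns a regular or a streaming dataset.

    `regular_cls` stands in for a TorchGeo dataset class and `streaming_cls`
    for a litdata StreamingDataset subclass; both are assumptions.
    """

    regular_cls: Callable[..., Any]
    streaming_cls: Callable[..., Any]
    streaming: bool = False
    regular_kwargs: Dict[str, Any] = field(default_factory=dict)
    streaming_kwargs: Dict[str, Any] = field(default_factory=dict)

    def build_dataset(self) -> Any:
        # The two paths differ only in which class is built and with which
        # arguments (e.g. a local root directory vs an input_dir on S3).
        if self.streaming:
            return self.streaming_cls(**self.streaming_kwargs)
        return self.regular_cls(**self.regular_kwargs)
```

This keeps one data module per dataset, but as noted above it still implies a second, pre-processed copy of the data behind `streaming_cls`.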

tchaton commented 9 months ago

Yes, unfortunately, having a copy is the only way to make things fast. This may be something we could help with. I will get back to you.

isaaccorley commented 9 months ago

The fMoW dataset would be a good use case for this. It's hosted on S3 and is already preprocessed into individual image patches instead of larger tiles.

adriantre commented 8 months ago

Does Lightning Streaming Dataset take into account random reads of subregions within a file? Geospatial data formats have some caveats here, since you trade off small file size (.jp2) against fast random reads (.geotiff).

tchaton commented 8 months ago

Hey @adriantre. Right now, the entire file is downloaded and the window is applied locally. However, I think it would be interesting to add slicing directly on S3 if possible.

adriantre commented 8 months ago

I see! When training with torchgeo, say with batch_size = 8, it is common for each slice/sample to be read from 8 different images. Each image is commonly 1 GB+ in size. The next batch is then read from either new images or the same ones.

Indeed, slicing directly on S3 is possible (for GDAL-compliant data formats).

tchaton commented 8 months ago

Hey @adriantre, thanks. That's fascinating. I will have a look and see if we can add support for fetching windows directly. This would be super neat. Do you know how they do it under the hood? Otherwise, I will investigate.

For the time being, though, this provides value when the files are smaller than the chunk size, e.g. when the dataset has already been pre-processed into smaller tiles.

adamjstewart commented 8 months ago

When training using torchgeo, say batch_size = 8, it is common that each slice/sample is read from 8 different images.

This is only true if you use RandomGeoSampler. We implemented RandomBatchGeoSampler for this exact scenario. With the latter, each mini-batch will only consist of patches from a single image.
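To make the contrast concrete, here is a simplified, pixel-space sketch of the two sampling strategies. It is not TorchGeo's implementation (the real samplers work in CRS coordinates over a spatial index), and all function names are hypothetical. It counts how many distinct source images each mini-batch touches, which is roughly how many whole files a download-then-slice streaming approach would fetch per batch.

```python
import random
from typing import List, Tuple

# (image index, row offset, col offset) of one patch, in pixel units
Patch = Tuple[int, int, int]


def _random_patch(rng: random.Random, image_idx: int, image_size: int, patch_size: int) -> Patch:
    # Pick a top-left corner so that the whole patch fits inside the image.
    row = rng.randint(0, image_size - patch_size)
    col = rng.randint(0, image_size - patch_size)
    return (image_idx, row, col)


def per_sample_batches(num_images: int, image_size: int, patch_size: int,
                       batch_size: int, num_batches: int, seed: int = 0) -> List[List[Patch]]:
    """RandomGeoSampler-like: every patch independently picks an image."""
    rng = random.Random(seed)
    return [
        [_random_patch(rng, rng.randrange(num_images), image_size, patch_size)
         for _ in range(batch_size)]
        for _ in range(num_batches)
    ]


def per_image_batches(num_images: int, image_size: int, patch_size: int,
                      batch_size: int, num_batches: int, seed: int = 0) -> List[List[Patch]]:
    """RandomBatchGeoSampler-like: one image is picked per batch."""
    rng = random.Random(seed)
    batches = []
    for _ in range(num_batches):
        image_idx = rng.randrange(num_images)
        batches.append([_random_patch(rng, image_idx, image_size, patch_size)
                        for _ in range(batch_size)])
    return batches


def images_per_batch(batches: List[List[Patch]]) -> List[int]:
    # Number of distinct source images (i.e. files to fetch) in each batch.
    return [len({patch[0] for patch in batch}) for batch in batches]
```

With 100 images and batch_size = 8, the per-image variant always touches one image per batch, while the per-sample variant usually touches close to 8. In TorchGeo itself, `RandomBatchGeoSampler` is passed to a `DataLoader` as `batch_sampler` rather than `sampler`.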

But I agree that windowed-reading support (such as implemented by GDAL/rasterio) within S3 would be nice for streaming.

adriantre commented 8 months ago

Hey @adriantre, thanks. That's fascinating. I will have a look and see if we can add support for window fetching directly. This would be super neat. Do you know how they do it under the hood ? Otherwise, I will investigate.

To my understanding, it is enabled by the file formats dividing big files into "smaller files" called blocks and storing them that way. The blocks can be read by parallel threads and accessed efficiently without scanning through the whole file, kinda like an index in a database (R-tree). The block size is an optimisation parameter that influences random-read speed, write speed, and file size, and each file format's driver knows how to utilise this. That's as far as my understanding goes, though 😅

Might be some more in-depth info here: http://ikcest-drr.osgeo.cn/tutorial/k1072
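As a rough illustration of why blocks help: a tiled GeoTIFF stores each fixed-size block separately along with a table of block byte offsets, so a reader such as GDAL (e.g. via its `/vsis3/` virtual file system) can fetch only the blocks a window overlaps using HTTP range requests. The helpers below are hypothetical and only compute which blocks are needed; they ignore the initial header/offset-table fetch and the fact that compressed blocks vary in byte size.

```python
import math
from typing import List, Tuple


def blocks_for_window(row_off: int, col_off: int, height: int, width: int,
                      block_h: int, block_w: int) -> List[Tuple[int, int]]:
    """Return (block_row, block_col) indices of the internal blocks a pixel window overlaps."""
    first_row, last_row = row_off // block_h, (row_off + height - 1) // block_h
    first_col, last_col = col_off // block_w, (col_off + width - 1) // block_w
    return [(r, c) for r in range(first_row, last_row + 1)
            for c in range(first_col, last_col + 1)]


def total_blocks(image_h: int, image_w: int, block_h: int, block_w: int) -> int:
    # Edge blocks are only partially filled, hence the ceiling division.
    return math.ceil(image_h / block_h) * math.ceil(image_w / block_w)
```

For example, a 256 × 256 window at row 1000, column 2000 of a 10980 × 10980 raster with 512 × 512 blocks overlaps 4 of the raster's 484 blocks, whereas downloading the whole file fetches all 484.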

tchaton commented 8 months ago

Interesting, that's similar to what the streaming dataset does with its chunks. However, I am chatting with the AWS team about supporting a more native solution.

adriantre commented 8 months ago

What I did not mention, and what torchgeo relies on to some degree, is that the drivers for reading these datasets have wrappers that let us read windows from the dataset by specifying geospatial coordinates (not only pixel bounds). The returned dataset handle also lets us reproject the pixels to a desired spatial reference system and resolution.

I think it is worth keeping this in mind when implementing Lightning Streaming Dataset, to lower the barrier for torchgeo to integrate it with rasterio/GDAL.
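The coordinate-based windowing mentioned above boils down to inverting the raster's geotransform. A minimal sketch, assuming a north-up raster with no rotation terms; `GeoTransform` and `window_from_bounds` are hypothetical names (rasterio offers the real equivalent as `rasterio.windows.from_bounds`), and reprojection, which rasterio handles separately (e.g. with `WarpedVRT`), is out of scope.

```python
import math
from typing import NamedTuple, Tuple


class GeoTransform(NamedTuple):
    """North-up GDAL-style geotransform without rotation terms (an assumption)."""
    x_origin: float      # x of the top-left corner, in CRS units
    pixel_width: float   # > 0
    y_origin: float      # y of the top-left corner, in CRS units
    pixel_height: float  # < 0 for north-up rasters


def window_from_bounds(gt: GeoTransform, minx: float, miny: float,
                       maxx: float, maxy: float) -> Tuple[int, int, int, int]:
    """Convert CRS bounds into a (row_off, col_off, height, width) pixel window."""
    col_start = math.floor((minx - gt.x_origin) / gt.pixel_width)
    col_stop = math.ceil((maxx - gt.x_origin) / gt.pixel_width)
    # With a negative pixel height, maxy maps to the smaller (top) row index.
    row_start = math.floor((maxy - gt.y_origin) / gt.pixel_height)
    row_stop = math.ceil((miny - gt.y_origin) / gt.pixel_height)
    return row_start, col_start, row_stop - row_start, col_stop - col_start
```

For a 10 m raster whose top-left corner is at (500000, 4600000), the bounds (501000, 4596440, 503560, 4599000) map to the pixel window (100, 100, 256, 256), which a windowed reader could then serve without downloading the rest of the file.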

robmarkcole commented 6 months ago

I've now created a couple of streaming datasets and have a common class that can handle them all. It just requires that the datasets be formatted in a standard way as a dict, which is actually the same format required for transforms anyway:

from typing import Any, Callable, Dict, Optional

import torch
from litdata import StreamingDataset
from rasterio.io import MemoryFile

class SegmentationStreamingDataset(StreamingDataset):
    """
    Segmentation dataset with streaming.

    Each streamed item is expected to be a dict with a "name" string and
    "image"/"mask" entries holding encoded raster bytes (e.g. GeoTIFF).

    Args:
        input_dir (str): Local directory or S3 location of the dataset
        transforms (Optional[Callable]): A transform that takes in a sample dict and returns a transformed version.
    """

    def __init__(self, *args: Any, transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, **kwargs: Any) -> None:
        # input_dir and any other StreamingDataset options are forwarded unchanged
        super().__init__(*args, **kwargs)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Fetch the raw item (litdata downloads its chunk if needed)
        data = super().__getitem__(index)
        image_name = data["name"]

        # Decode the in-memory raster bytes: image as float, mask as integer class labels
        with MemoryFile(data["image"]) as memfile:
            with memfile.open() as dataset:
                image = torch.from_numpy(dataset.read()).float()

        with MemoryFile(data["mask"]) as memfile:
            with memfile.open() as dataset:
                mask = torch.from_numpy(dataset.read()).long()

        sample = {"image": image, "mask": mask, "image_name": image_name}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
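For the writing side, here is a sketch of a per-sample function that produces the dict layout `SegmentationStreamingDataset` reads back (`"name"`, plus raw encoded `"image"` and `"mask"` bytes that `MemoryFile` can decode). `to_streaming_sample` is a hypothetical name, and the assumption that each input is an (image path, mask path) pair is ours.

```python
from pathlib import Path
from typing import Dict, Tuple, Union

PathLike = Union[str, Path]


def to_streaming_sample(pair: Tuple[PathLike, PathLike]) -> Dict[str, Union[str, bytes]]:
    """Pack one (image, mask) file pair into the dict layout read by the streaming dataset.

    Files are kept as their original encoded bytes (e.g. GeoTIFF), so the
    consumer can decode them with rasterio's MemoryFile.
    """
    image_path, mask_path = Path(pair[0]), Path(pair[1])
    return {
        "name": image_path.stem,           # read back as data["name"]
        "image": image_path.read_bytes(),  # read back via MemoryFile(data["image"])
        "mask": mask_path.read_bytes(),    # read back via MemoryFile(data["mask"])
    }
```

A function like this would then be handed to litdata's `optimize` as its `fn`, with the list of pairs as `inputs` and an S3 or local `output_dir`. Check the litdata documentation for the exact arguments (e.g. `chunk_bytes`, `num_workers`) before relying on them.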