tchaton opened 9 months ago
@robmarkcole has been exploring this and may be able to comment on how feasible it is to integrate this into TorchGeo.
It is pretty straightforward in 2 steps: (1) process the dataset into the required binary format and host it on a cloud (I am using AWS; TBC which cloud would host these public datasets), and (2) each dataset then has a complementary streaming version
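A minimal sketch of step (1), using litdata's optimize entry point; the list of patch paths and the S3 output location below are hypothetical. Each sample returned by the function is serialized into fixed-size binary chunks that can be hosted on a bucket and streamed during training.

from glob import glob

from litdata import optimize


def to_sample(image_path):
    # read the raw bytes; litdata serializes the returned dict into its chunk format
    with open(image_path, "rb") as f:
        return {"name": image_path, "image": f.read()}


if __name__ == "__main__":
    optimize(
        fn=to_sample,
        inputs=glob("data/patches/*.tif"),  # hypothetical local patches
        output_dir="s3://my-bucket/optimized/train",  # hypothetical S3 location
        chunk_bytes="64MB",
    )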
each dataset then has a complementary streaming version
This is what I'm trying to avoid. I don't want 2+ copies of all 75+ datasets in TorchGeo. I'm fine with adding a new subclass of GeoDataset or RasterDataset though.
I'm not sure there is a way around it. In my implementation, the data module accepts a streaming arg and returns either the regular or the streaming dataset
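A minimal, self-contained sketch of that pattern (all class and argument names are hypothetical, not the actual implementation): the datamodule takes a streaming flag and builds either a regular map-style dataset over local patches or a litdata StreamingDataset over the optimized copy.

import lightning.pytorch as pl
import rasterio
import torch
from litdata import StreamingDataLoader, StreamingDataset
from torch.utils.data import DataLoader, Dataset


class PatchDataset(Dataset):
    """Regular dataset: reads pre-chipped GeoTIFF patches from local paths."""

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        with rasterio.open(self.paths[index]) as src:
            return {"image": torch.from_numpy(src.read()).float()}


class SegmentationDataModule(pl.LightningDataModule):
    def __init__(self, patch_paths, streaming_dir, streaming=False, batch_size=8):
        super().__init__()
        self.patch_paths = patch_paths
        self.streaming_dir = streaming_dir  # e.g. an s3:// prefix of optimized chunks
        self.streaming = streaming
        self.batch_size = batch_size

    def setup(self, stage=None):
        if self.streaming:
            self.train_dataset = StreamingDataset(input_dir=self.streaming_dir)
        else:
            self.train_dataset = PatchDataset(self.patch_paths)

    def train_dataloader(self):
        loader_cls = StreamingDataLoader if self.streaming else DataLoader
        return loader_cls(self.train_dataset, batch_size=self.batch_size)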
Yes, unfortunately, having a copy is the only way to make things fast. This is maybe something we could help with. I will come back to you.
The fMoW dataset would be a good use case for this. It's hosted on S3 and is already preprocessed into individual image patches instead of larger tiles.
Does the Lightning Streaming Dataset take random reading of subregions within a file into account? With geospatial data formats there are some caveats here when choosing between small file size (.jp2) and fast random reading (.geotiff).
Hey @adriantre. Right now, the entire file is downloaded and the window is applied locally. However, I think it would be interesting to add slicing on s3 directly if possible.
I see! When training using torchgeo, say batch_size = 8, it is common that each slice/sample is read from 8 different images. Each image is commonly 1GB+ in size. Then the next batch is read from either new or the same images.
Indeed, slicing on S3 directly is possible (for GDAL-compliant data formats).
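A minimal sketch of windowed reading directly from S3 with rasterio (the bucket and key are hypothetical; S3 credentials are assumed to be configured). Rasterio/GDAL issues range requests, so only the blocks overlapping the window are fetched rather than the whole file.

import rasterio
from rasterio.windows import Window

with rasterio.open("s3://my-bucket/scene.tif") as src:  # hypothetical path
    # read a 256x256 pixel window starting at column 1024, row 2048
    window = Window(col_off=1024, row_off=2048, width=256, height=256)
    patch = src.read(window=window)  # shape: (bands, 256, 256)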
Hey @adriantre, thanks. That's fascinating. I will have a look and see if we can add support for window fetching directly. This would be super neat. Do you know how they do it under the hood? Otherwise, I will investigate.
But yeah, for the time being, this provides value when the files are smaller than the chunk size, e.g. when the dataset has already been pre-processed into smaller tiles.
When training using torchgeo, say batch_size = 8, it is common that each slice/sample is read from 8 different images.
This is only true if you use RandomGeoSampler. We implemented RandomBatchGeoSampler for this exact scenario. With the latter, each mini-batch will only consist of patches from a single image.
But I agree that windowed-reading support (such as implemented by GDAL/rasterio) within S3 would be nice for streaming.
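A short usage sketch of RandomBatchGeoSampler (the data directory is hypothetical): because every mini-batch is sampled from a single tile, only one large file needs to be opened per batch.

from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.samplers import RandomBatchGeoSampler

dataset = RasterDataset("data/scenes")  # hypothetical directory of GeoTIFFs
sampler = RandomBatchGeoSampler(dataset, size=256, batch_size=8, length=1000)
dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=stack_samples)

for batch in dataloader:
    images = batch["image"]  # (8, bands, 256, 256), all patches from the same tile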
Hey @adriantre, thanks. That's fascinating. I will have a look and see if we can add support for window fetching directly. This would be super neat. Do you know how they do it under the hood? Otherwise, I will investigate.
To my understanding, it is enabled by the file formats dividing and storing big files as "smaller files" called blocks. The blocks can be accessed in parallel threads and can be read efficiently without scanning through the whole file, kinda like an index in a database (r-tree). The block size is an optimisation parameter that influences random-reading speed, writing speed, and file size, and the per-file-format driver knows how to utilise this. That's as far as my understanding goes though 😅
Might be some more in-depth info here: http://ikcest-drr.osgeo.cn/tutorial/k1072
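A small sketch (hypothetical filenames) of inspecting and setting that internal block layout with rasterio; tiled GeoTIFFs with e.g. 512x512 blocks are what make partial/windowed reads cheap.

import rasterio

with rasterio.open("scene.tif") as src:  # hypothetical input
    print(src.profile.get("tiled"), src.block_shapes)  # e.g. True, [(512, 512), ...]

    # rewrite the file with an explicit tiled block layout
    profile = src.profile.copy()
    profile.update(tiled=True, blockxsize=512, blockysize=512, compress="deflate")
    with rasterio.open("scene_tiled.tif", "w", **profile) as dst:
        dst.write(src.read())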
Interesting, similar to what the streaming dataset does with its chunks. However, I am chatting with the AWS team to support a more native solution.
What I did not mention, and which torchgeo to some degree relies on, is that the drivers for reading these datasets have wrappers that let us read windows from the dataset by specifying geospatial coordinates (not only pixel bounds). And the returned dataset handle lets us reproject the pixels to a desired spatial reference system and resolution.
I think it is worth keeping this in mind for the implementation in the Lightning Streaming Dataset to lower the barrier for torchgeo to adapt it to rasterio/GDAL.
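A minimal sketch (hypothetical file path) of the two capabilities described above with rasterio: reading a window specified by geospatial coordinates, and reprojecting on the fly with WarpedVRT (target resolution can likewise be set via its transform/width/height arguments).

import rasterio
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from rasterio.windows import from_bounds

with rasterio.open("scene.tif") as src:  # hypothetical input
    # reproject/resample the dataset on the fly to a target CRS
    with WarpedVRT(src, crs="EPSG:4326", resampling=Resampling.bilinear) as vrt:
        # window given in geospatial coordinates (here: lon/lat bounds)
        window = from_bounds(10.0, 59.0, 10.1, 59.1, transform=vrt.transform)
        patch = vrt.read(window=window)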
I've now created a couple of streaming datasets and have a common class that can handle them all - it just requires that the datasets are formatted in a standard way as a dict, which is actually the same format required for transforms anyway:
import torch
from litdata import StreamingDataset
from rasterio.io import MemoryFile


class SegmentationStreamingDataset(StreamingDataset):
    """Segmentation dataset with streaming.

    Args:
        input_dir (str): Local directory or S3 location of the dataset
        transforms (Optional[Callable]): A transform that takes in an image and returns a transformed version.
    """

    def __init__(self, *args, transforms=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.transforms = transforms

    def __getitem__(self, index) -> dict:
        data = super().__getitem__(index)
        image_name = data["name"]
        image = data["image"]
        mask = data["mask"]
        # decode the serialized GeoTIFF bytes via rasterio's in-memory file
        with MemoryFile(image) as memfile:
            with memfile.open() as dataset:
                image = torch.from_numpy(dataset.read()).float()
        with MemoryFile(mask) as memfile:
            with memfile.open() as dataset:
                mask = torch.from_numpy(dataset.read()).long()
        sample = {"image": image, "mask": mask, "image_name": image_name}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
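A short usage sketch, assuming the dataset above has already been optimized to a (hypothetical) S3 location with litdata:

from litdata import StreamingDataLoader

dataset = SegmentationStreamingDataset(input_dir="s3://my-bucket/segmentation/train")
dataloader = StreamingDataLoader(dataset, batch_size=8, num_workers=4)

for batch in dataloader:
    images, masks = batch["image"], batch["mask"]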
Summary
Dear people of TorchGeo,
I am hearing from more folks in the community who are training in the cloud and struggling with their data streaming solutions.
I have been working on this new framework: https://github.com/Lightning-AI/pytorch-lightning/tree/master/src/lightning/data.
I wondered if you would be interested in giving it a spin.
Rationale
Make it easier for people to consume datasets in the cloud
Implementation
No response
Alternatives
No response
Additional information
No response