microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Front page example for VHR10 dataset does not work #1686

Closed grantcurell closed 8 months ago

grantcurell commented 1 year ago

Description

I started by copying and pasting the example as is from the front page:

from torch.utils.data import DataLoader

from torchgeo.datasets import VHR10

dataset = VHR10(root="./raw_data", download=True, checksum=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

for batch in dataloader:
    image = batch["image"]
    label = batch["label"]

    # train a model, or make predictions using a pre-trained model

This produces:

/usr/bin/python3.10 /home/grant/Documents/code/geo_testing/NAIP_test/3_test.py 
Files already downloaded and verified
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
Traceback (most recent call last):
  File "/home/grant/Documents/code/geo_testing/NAIP_test/3_test.py", line 8, in <module>
    for batch in dataloader:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 127, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 127, in <dictcomp>
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 162, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 663, 794] at entry 0 and [3, 551, 808] at entry 1
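
For reference, the same failure can be reproduced without downloading anything. This is only a minimal sketch: it assumes samples are dicts holding an "image" tensor whose height and width vary (the shapes are copied from the error above), and `list_collate` is a hypothetical helper, not something torchgeo provides.

```python
import torch
from torch.utils.data import default_collate

# Two fake samples with the shapes reported in the traceback above.
samples = [
    {"image": torch.zeros(3, 663, 794)},
    {"image": torch.zeros(3, 551, 808)},
]

# The default collate function tries to torch.stack the images,
# which fails when their shapes differ.
try:
    default_collate(samples)
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print(err)


def list_collate(batch):
    """Hypothetical helper: keep each key as a list instead of stacking."""
    return {key: [sample[key] for sample in batch] for key in batch[0]}


batch = list_collate(samples)
print([tuple(img.shape) for img in batch["image"]])
```

Passing a function like this as `collate_fn` avoids the crash, but the images still have different sizes, so they would need resizing or padding before going through a model as one tensor.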

Steps to reproduce

  1. Copy and paste the code from the front-page example
  2. Update the root directory to something valid
  3. Run with no additional options

Version

0.5.0

adamjstewart commented 1 year ago

We could probably solve this with a Resize augmentation but let's actually choose a simpler dataset, VHR-10 is kind of complicated. @ashnair1 is working on a data module for VHR-10 which will make it easier to use, and will include augmentations like this: #1082

grantcurell commented 1 year ago

I'm not sure if you want it, but after staring at this for a while to figure out what I'm looking at (I've never used any of this stuff before), this is what I came up with:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchgeo.datasets import VHR10

# Define the resize transform
resize_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((1024, 1024)),
    transforms.ToTensor()
])

# Custom collate function
def custom_collate(batch):
    images = [item["image"] for item in batch]
    labels = [item["labels"] for item in batch]

    resized_images = [resize_transform(img) for img in images]
    resized_images = torch.stack(resized_images)

    # Since labels can have different lengths, we keep them as a list instead of stacking
    return {"image": resized_images, "labels": labels}

# Initialize the dataset
dataset = VHR10(root="./raw_data", download=True, checksum=True)

# Initialize the dataloader with the custom collate function
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, collate_fn=custom_collate)

# Training loop
for batch in dataloader:
    image = batch["image"]
    labels = batch["labels"]

I can open a PR with an explanation for newcomers if you want, but like you said, I'm not sure it's what you're after.

adamjstewart commented 1 year ago

Our datasets aren't really compatible with torchvision transforms, you'll have much better luck with kornia transforms. Something like:

from kornia.augmentation import Resize
from torchgeo.transforms import AugmentationSequential

transforms = AugmentationSequential(
    Resize(..., ...), data_keys=["image"]
)

See https://torchgeo.readthedocs.io/en/stable/tutorials/transforms.html for more examples.

I tried this and I don't think it's compatible with our current design of VHR10. This should be reworked in #1082. In the meantime, it's probably easier to give an example using a different dataset where images don't require resizing.

adamjstewart commented 1 year ago

@grantcurell do you want to submit a PR to change the example dataset from VHR10 to EuroSAT while we wait for #1082 to be merged? EuroSAT should be a much simpler example.

grantcurell commented 1 year ago

> @grantcurell do you want to submit a PR to change the example dataset from VHR10 to EuroSAT while we wait for #1082 to be merged? EuroSAT should be a much simpler example.

Apologies for my delayed response. I've already done all my other modeling with the VHR10 dataset so for me that's what I'll probably stick with.

If I get the chance though, I'll write something up for EuroSAT.

connorlee77 commented 11 months ago

> We could probably solve this with a Resize augmentation but let's actually choose a simpler dataset, VHR-10 is kind of complicated. @ashnair1 is working on a data module for VHR-10 which will make it easier to use, and will include augmentations like this: #1082

@adamjstewart Can you explain why it's complicated? I'm facing a similar issue with the Chesapeake dataset. I read the corresponding datamodule code, but it's unclear why resizing the image before applying a crop is the ideal solution, especially for applications where the pixel resolution matters. Furthermore, each tile in the dataset is quite large, so I'm also not sure why this resizing is even necessary.

adamjstewart commented 11 months ago

VHR-10 is complicated because it has images, masks, and bounding boxes. Chesapeake only has images and masks, so it's much easier. I think the problem with Chesapeake is slightly different since it's a GeoDataset. In theory, resize/crop shouldn't be needed, but it's needed right now because it's not using the RasterDataset base class. If you open a separate issue maybe @calebrob6 can take a look at fixing this.
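
To illustrate why boxes complicate things: resizing an image changes its pixel coordinates, so every box has to be rescaled by the same factors or it no longer lines up with the object. A rough sketch in plain Python, assuming boxes are stored as (xmin, ymin, xmax, ymax) in pixels; `resize_boxes` is a hypothetical helper for illustration, not a torchgeo function.

```python
def resize_boxes(boxes, old_size, new_size):
    """Rescale (xmin, ymin, xmax, ymax) boxes from old_size to new_size.

    Sizes are (height, width) tuples. Hypothetical helper for illustration.
    """
    old_h, old_w = old_size
    new_h, new_w = new_size
    scale_x = new_w / old_w
    scale_y = new_h / old_h
    return [
        (xmin * scale_x, ymin * scale_y, xmax * scale_x, ymax * scale_y)
        for xmin, ymin, xmax, ymax in boxes
    ]


# One box on a 400 x 800 (height x width) image resized to 200 x 200.
boxes = [(80.0, 40.0, 160.0, 120.0)]
resized = resize_boxes(boxes, old_size=(400, 800), new_size=(200, 200))
print(resized)  # [(20.0, 20.0, 40.0, 60.0)]
```

Masks need the same treatment (resampled to the new size), which is why a transform that only touches the "image" key isn't enough for a detection dataset.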

adamjstewart commented 8 months ago

@ashnair1 has this gotten any better now that #1082 has been merged? Or should we switch to a simpler dataset that does not require transforms to use?

ashnair1 commented 8 months ago

The following example will work, though a simpler dataset might be better suited for the README.

import kornia.augmentation as K
import torch
from torch.utils.data import DataLoader

from torchgeo.datamodules.utils import AugPipe, collate_fn_detection
from torchgeo.datasets import VHR10
from torchgeo.transforms import AugmentationSequential

batch_size = 2

# Initialize the dataset
dataset = VHR10(root="./raw_data/", download=True, checksum=True)

# Initialize the dataloader with the custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn_detection,
)

# Initialize augs to normalize and resize images to size (512, 512)
aug = AugPipe(
    augs=AugmentationSequential(
        K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
        K.Resize((512, 512)),
        data_keys=["image", "boxes", "masks"],
    ),
    batch_size=batch_size,
)

# Training loop
for batch in dataloader:
    batch = aug(batch)
    images = batch["image"] # List of images
    boxes = batch["boxes"] # List of boxes
    labels = batch["labels"] # List of labels
    masks = batch["masks"] # List of masks

adamjstewart commented 8 months ago

I do really like the VHR-10 pic we use in the README though... Want to submit a PR to use that code to fix the README example? I would also accept a PR that uses a different dataset like EuroSAT instead. We're trying to release 0.5.2 tomorrow or Saturday, so it kinda needs to happen fast if we want to get this fixed before the next release.

ashnair1 commented 8 months ago

Ok, let's go with VHR-10 (#1920) for now. We can always switch later.