We could probably solve this with a Resize augmentation, but let's actually choose a simpler dataset; VHR-10 is kind of complicated. @ashnair1 is working on a data module for VHR-10 which will make it easier to use, and will include augmentations like this: #1082
I'm not sure if you want it, but after staring at this for a while to figure out what I'm looking at (I've never used any of this stuff before), this is what I came up with:
```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchgeo.datasets import VHR10

# Define the resize transform
resize_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
])

# Custom collate function
def custom_collate(batch):
    images = [item["image"] for item in batch]
    labels = [item["labels"] for item in batch]
    resized_images = [resize_transform(img) for img in images]
    resized_images = torch.stack(resized_images)
    # Since labels can have different lengths, we keep them as a list instead of stacking
    return {"image": resized_images, "labels": labels}

# Initialize the dataset
dataset = VHR10(root="./raw_data", download=True, checksum=True)

# Initialize the dataloader with the custom collate function
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4, collate_fn=custom_collate)

# Training loop
for batch in dataloader:
    image = batch["image"]
    labels = batch["labels"]
```
I can PR it with an explanation for newbies if you want, but, like you said, I'm not sure if it's what you want.
Our datasets aren't really compatible with torchvision transforms, you'll have much better luck with kornia transforms. Something like:
```python
from kornia.augmentation import Resize
from torchgeo.transforms import AugmentationSequential

transforms = AugmentationSequential(
    Resize(..., ...), data_keys=["image"]
)
```
See https://torchgeo.readthedocs.io/en/stable/tutorials/transforms.html for more examples.
I tried this and I don't think it's compatible with our current design of VHR10. This should be reworked in #1082. In the meantime, it's probably easier to give an example using a different dataset where images don't require resizing.
@grantcurell do you want to submit a PR to change the example dataset from VHR10 to EuroSAT while we wait for #1082 to be merged? EuroSAT should be a much simpler example.
Apologies for my delayed response. I've already done all my other modeling with the VHR10 dataset so for me that's what I'll probably stick with.
If I get the chance though, I'll write something up for EuroSAT.
@adamjstewart Can you explain why it's complicated? I'm facing a similar issue with the Chesapeake dataset. I read the corresponding datamodule code, but it's unclear why resizing the image before applying a crop is the ideal solution, especially for applications where the pixel resolution matters. Furthermore, each tile in the dataset is quite large, so I'm also not sure why this resizing is even necessary.
VHR-10 is complicated because it has images, masks, and bounding boxes. Chesapeake only has images and masks, so it's much easier. I think the problem with Chesapeake is slightly different since it's a GeoDataset. In theory, resize/crop shouldn't be needed, but it's needed right now because it's not using the RasterDataset base class. If you open a separate issue maybe @calebrob6 can take a look at fixing this.
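To illustrate why the bounding boxes are the tricky part: resizing an image only touches pixels, but every box has to be rescaled by the same factors, and masks have to be resized with nearest-neighbour interpolation so the class values stay intact. The sketch below covers only the box part in plain Python. It assumes boxes are `[xmin, ymin, xmax, ymax]` in pixel coordinates and that sizes are given as `(height, width)`; `resize_boxes` is a hypothetical helper, not a torchgeo function.

```python
def resize_boxes(boxes, old_size, new_size):
    """Rescale [xmin, ymin, xmax, ymax] boxes when an image goes from old_size to new_size.

    Hypothetical helper; sizes are (height, width) tuples.
    """
    old_h, old_w = old_size
    new_h, new_w = new_size
    # x coordinates scale with the width ratio, y coordinates with the height ratio
    sx = new_w / old_w
    sy = new_h / old_h
    return [
        [xmin * sx, ymin * sy, xmax * sx, ymax * sy]
        for xmin, ymin, xmax, ymax in boxes
    ]


# A 1000 x 800 (height x width) image resized to 512 x 512
boxes = [[100, 200, 300, 400]]
resized = resize_boxes(boxes, (1000, 800), (512, 512))
print(resized)
```

Kornia's `AugmentationSequential` with a `"boxes"` data key is meant to do this bookkeeping for you, which is why the dataset's sample format has to line up with it.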
@ashnair1 has this gotten any better now that #1082 has been merged? Or should we switch to a simpler dataset that does not require transforms to use?
The following example will work. Though a simpler dataset might be better suited for the README.
```python
import kornia.augmentation as K
import torch
from torch.utils.data import DataLoader
from torchgeo.datamodules.utils import AugPipe, collate_fn_detection
from torchgeo.datasets import VHR10
from torchgeo.transforms import AugmentationSequential

batch_size = 2

# Initialize the dataset
dataset = VHR10(root="./raw_data/", download=True, checksum=True)

# Initialize the dataloader with the custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn_detection,
)

# Initialize augs to normalize and resize images to size (512, 512)
aug = AugPipe(
    augs=AugmentationSequential(
        K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
        K.Resize((512, 512)),
        data_keys=["image", "boxes", "masks"],
    ),
    batch_size=batch_size,
)

# Training loop
for batch in dataloader:
    batch = aug(batch)
    images = batch["image"]  # List of images
    boxes = batch["boxes"]  # List of boxes
    labels = batch["labels"]  # List of labels
    masks = batch["masks"]  # List of masks
```
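The reason the batch holds lists rather than stacked tensors is that each image can contain a different number of objects, so boxes and labels cannot be stacked into one tensor. A minimal pure-Python sketch of a detection-style collate that keeps one list per key is shown below. `collate_as_lists` is a hypothetical name and this is not torchgeo's `collate_fn_detection` implementation; it only shows the idea of not stacking ragged targets.

```python
def collate_as_lists(batch):
    """Hypothetical detection-style collate: one list per key, nothing is stacked."""
    output = {}
    for key in batch[0]:
        # Each entry keeps its own length, e.g. a different number of boxes per image
        output[key] = [sample[key] for sample in batch]
    return output


samples = [
    {"image": "img0", "boxes": [[0, 0, 10, 10]], "labels": [1]},
    {"image": "img1", "boxes": [[5, 5, 20, 20], [1, 1, 2, 2]], "labels": [3, 4]},
]
batch = collate_as_lists(samples)
print(batch["labels"])
```

A model such as torchvision's Mask R-CNN expects exactly this kind of per-image list of targets, which is why stacking is skipped.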
I do really like the VHR-10 pic we use in the README though... Want to submit a PR to use that code to fix the README example? I would also accept a PR that uses a different dataset like EuroSAT instead. We're trying to release 0.5.2 tomorrow or Saturday, so it kinda needs to happen fast if we want to get this fixed before the next release.
Ok, let's go with VHR-10 (#1920) for now. We can always switch later.
Description
I started by copying and pasting the example as-is from the front page:
This produces:
Steps to reproduce
Version
0.5.0