lcoandrade opened this issue 1 year ago
Duplicate of #1056 and #1418
The issue is that some of the values in the samples returned by GeoDataset (BoundingBox, CRS) can't be automatically collated by PyTorch. Our solution for our built-in data modules is to remove these values before loading: https://github.com/microsoft/torchgeo/blob/v0.4.1/torchgeo/datamodules/geo.py#L280
My suggestion would be to write a simple data module (there are dozens of built-in examples) and use it instead of using a data loader directly. Maybe this is something we could add to our collation functions...
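A minimal sketch of the "add it to our collation functions" idea, in plain Python: strip the non-collatable values from every sample, then hand the rest to an existing collate function such as torchgeo's `stack_samples`. The helper names (`drop_keys`, `collate_without_keys`, `make_collate_fn`) are hypothetical and not part of the torchgeo API, and the key names `"crs"` and `"bbox"` are assumed from the v0.4.x sample format linked above.

```python
import functools
from typing import Any, Callable, Dict, Iterable, List, Sequence

# Assumption: GeoDataset samples (torchgeo v0.4.x) keep non-tensor metadata
# under these keys (a CRS object and a BoundingBox object).
NON_TENSOR_KEYS = ("crs", "bbox")


def drop_keys(sample: Dict[str, Any], keys: Iterable[str] = NON_TENSOR_KEYS) -> Dict[str, Any]:
    """Return a shallow copy of ``sample`` without the given keys."""
    excluded = set(keys)
    return {key: value for key, value in sample.items() if key not in excluded}


def collate_without_keys(
    samples: Sequence[Dict[str, Any]],
    collate_fn: Callable[[List[Dict[str, Any]]], Any],
    keys: Iterable[str] = NON_TENSOR_KEYS,
) -> Any:
    """Strip ``keys`` from every sample, then delegate to ``collate_fn``."""
    keys = tuple(keys)  # materialize once so a generator is not exhausted
    return collate_fn([drop_keys(sample, keys) for sample in samples])


def make_collate_fn(
    collate_fn: Callable[[List[Dict[str, Any]]], Any],
    keys: Iterable[str] = NON_TENSOR_KEYS,
) -> Callable[[Sequence[Dict[str, Any]]], Any]:
    """Bind ``collate_fn`` with functools.partial rather than a closure,
    so the result stays picklable for DataLoader worker processes."""
    return functools.partial(collate_without_keys, collate_fn=collate_fn, keys=tuple(keys))
```

In the reproduction from the issue description this would be used as `DataLoader(dataset, sampler=sampler, collate_fn=make_collate_fn(stack_samples))`. `functools.partial` over module-level functions is used so the collate function can still be pickled when `num_workers > 0` with the spawn start method. The trade-off is that batches no longer carry the CRS and bounding boxes, so anything that needs them later (for example, georeferencing predictions) has to obtain them another way.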
Is this still an issue or can this be closed?
I've made a CustomGeoDataModule like this:
import torch
from torchgeo.datamodules import GeoDataModule
from torchgeo.datasets import random_bbox_assignment
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler


class CustomGeoDataModule(GeoDataModule):
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = self.dataset_class(**self.kwargs)

        # Split the dataset's bounding boxes 60/20/20 with a fixed seed
        generator = torch.Generator().manual_seed(0)
        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = random_bbox_assignment(self.dataset, [0.6, 0.2, 0.2], generator)

        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )
That solved my problem.
Hi @adamjstewart
I've also encountered this problem, and it took me a while to find the solution. Definitely +1 for adding this as a TorchGeo feature, to make working with GeoDatasets as seamless as possible for end users.
Cheers, Tom
Reopening as a reminder to try to upstream some of our changes to PyTorch.
Description
I recently learned about TorchGeo and wanted to try it, so I created a Kaggle notebook (Torchgeo 101) to test it with NAIP and Chesapeake data. When I try to train a segmentation task, I get the following error:
ValueError: A frozen dataclass was passed to `apply_to_collection` but this is not allowed.
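Why the error mentions a frozen dataclass: as far as I can tell from the lightning_utilities source, when Lightning moves a batch to the device, `apply_to_collection` copies each dataclass instance it finds and reassigns its fields with `setattr`, converting the resulting `FrozenInstanceError` into the `ValueError` above. The plain-Python sketch below imitates that pattern in simplified form (a shallow copy instead of a deep copy). `FakeBoundingBox`, `MutableBox`, and `rebuild_fields` are hypothetical names, and it is an assumption that the frozen dataclass in this batch is the BoundingBox stored under `"bbox"`.

```python
import copy
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class FakeBoundingBox:
    """Hypothetical frozen stand-in for a bounding-box object in a batch."""

    minx: float
    maxx: float
    miny: float
    maxy: float


@dataclasses.dataclass
class MutableBox:
    """The same fields without ``frozen=True``, for comparison."""

    minx: float
    maxx: float
    miny: float
    maxy: float


def rebuild_fields(obj: Any, fn: Callable[[Any], Any]) -> Any:
    """Copy a dataclass instance and reassign each field to ``fn(value)``.

    Simplified imitation of a generic "map over a collection" helper that
    rebuilds dataclasses field by field. On a frozen dataclass the first
    ``setattr`` raises ``dataclasses.FrozenInstanceError``.
    """
    result = copy.copy(obj)
    for field in dataclasses.fields(obj):
        setattr(result, field.name, fn(getattr(obj, field.name)))
    return result
```

A mutable dataclass goes through this rebuild without complaint, which is consistent with the error only appearing while the BoundingBox objects are still in the batch; removing the `"bbox"` and `"crs"` entries before the transfer avoids it.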
Steps to reproduce
import multiprocessing as mp
import os
import ssl

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader
from torchgeo.datasets import ChesapeakeDE, stack_samples
from torchgeo.samplers import RandomGeoSampler
from torchgeo.trainers import SemanticSegmentationTask

# Creating the CHESAPEAKE dataset
chesapeake_root = os.path.join(INPUT_DIR, "chesapeake")
chesapeake = ChesapeakeDE(
    chesapeake_root, crs=naip.crs, res=naip.res, download=False
)

# Intersect NAIP imagery with Chesapeake labels and sample random patches
dataset = naip & chesapeake
sampler = RandomGeoSampler(dataset, size=IMG_SIZE, length=SAMPLE_SIZE)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

# Disable HTTPS certificate verification for downloads
ssl._create_default_https_context = ssl._create_unverified_context

# Output directory, logger and checkpointing (defined before the Trainer uses them)
test_dir = os.path.join(OUTPUT_DIR, "test")
os.makedirs(test_dir, exist_ok=True)
logger = CSVLogger(test_dir, name="torchgeo_logs")
checkpoint_callback = ModelCheckpoint(
    every_n_epochs=1, dirpath=test_dir, filename="torchgeo_trained"
)

DEVICE, NUM_DEVICES = (
    ("cuda", torch.cuda.device_count())
    if torch.cuda.is_available()
    else ("cpu", mp.cpu_count())
)
WORKERS = mp.cpu_count()
print(f"Running on {NUM_DEVICES} {DEVICE}(s)")

trainer = pl.Trainer(
    accelerator=DEVICE,
    devices=NUM_DEVICES,
    max_epochs=EPOCHS,
    callbacks=[checkpoint_callback],
    logger=logger,
)

task = SemanticSegmentationTask(
    model=SEGMENTATION_MODEL,
    backbone=BACKBONE,
    weights=WEIGHTS,
    in_channels=IN_CHANNELS,
    num_classes=NUM_CLASSES,
    loss=LOSS,
    ignore_index=None,
    learning_rate=LR,
    learning_rate_schedule_patience=PATIENCE,
)

trainer.fit(model=task, train_dataloaders=dataloader)