microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Problem training with standard dataloader. #1426

Open · lcoandrade opened this issue 1 year ago

lcoandrade commented 1 year ago

Description

I've just learned about TorchGeo and got interested in using it, so I created a Kaggle notebook (Torchgeo 101) to test it with NAIP and Chesapeake data. When I try to train a segmentation task, I get the following error: ValueError: A frozen dataclass was passed to `apply_to_collection` but this is not allowed.

Steps to reproduce

1. Create a dataset with NAIP and Chesapeake data:

        # Creating the NAIP dataset
        naip_root = os.path.join(INPUT_DIR, 'naip')
        naip = NAIP(naip_root)

        # Creating the Chesapeake dataset
        chesapeake_root = os.path.join(INPUT_DIR, "chesapeake")
        chesapeake = ChesapeakeDE(
            chesapeake_root, crs=naip.crs, res=naip.res, download=False
        )

2. Make an intersection, create a sampler and a dataloader:

        dataset = naip & chesapeake
        sampler = RandomGeoSampler(dataset, size=IMG_SIZE, length=SAMPLE_SIZE)
        dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

3. Define a trainer:

        DEVICE, NUM_DEVICES = (
            ("cuda", torch.cuda.device_count())
            if torch.cuda.is_available()
            else ("cpu", mp.cpu_count())
        )
        WORKERS = mp.cpu_count()
        print(f'Running on {NUM_DEVICES} {DEVICE}(s)')

        trainer = pl.Trainer(
            accelerator=DEVICE,
            devices=NUM_DEVICES,
            max_epochs=EPOCHS,
            callbacks=[checkpoint_callback],
            logger=logger,
        )

4. Define a segmentation task:

        ssl._create_default_https_context = ssl._create_unverified_context

        test_dir = os.path.join(OUTPUT_DIR, "test")
        if not os.path.exists(test_dir):
            os.makedirs(test_dir)

        logger = CSVLogger(test_dir, name='torchgeo_logs')

        checkpoint_callback = ModelCheckpoint(
            every_n_epochs=1, dirpath=test_dir, filename='torchgeo_trained'
        )

        task = SemanticSegmentationTask(
            model=SEGMENTATION_MODEL,
            backbone=BACKBONE,
            weights=WEIGHTS,
            in_channels=IN_CHANNELS,
            num_classes=NUM_CLASSES,
            loss=LOSS,
            ignore_index=None,
            learning_rate=LR,
            learning_rate_schedule_patience=PATIENCE,
        )

5. Start training:

        trainer.fit(
            model=task,
            train_dataloaders=dataloader,
        )



### Version

0.4.1

adamjstewart commented 1 year ago

Duplicate of #1056 and #1418

The issue is that some of the sample values returned by GeoDataset can't be automatically collated by PyTorch (BoundingBox, CRS). Our solution for our builtin data modules is to remove these values before loading: https://github.com/microsoft/torchgeo/blob/v0.4.1/torchgeo/datamodules/geo.py#L280

My suggestion would be to write a simple data module (there are dozens of builtin examples) and use that instead of directly using a data loader. Maybe this is something we could add to our collation functions...
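
Until something like that lands, a minimal sketch of the same idea for a plain DataLoader is to wrap `stack_samples` in a collate function that drops the non-collatable entries first. This is not a TorchGeo API, just a workaround, and it assumes the 0.4.x sample keys (`crs` and `bbox`); adjust the key names for other releases:

    # Workaround sketch (not part of TorchGeo): strip the geo metadata that
    # PyTorch/Lightning cannot collate, then stack the rest as usual.
    # Assumes TorchGeo 0.4.x sample dicts with 'crs' and 'bbox' keys.
    from typing import Any, Dict, List

    from torch.utils.data import DataLoader
    from torchgeo.datasets import stack_samples


    def collate_without_geo(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Drop CRS/BoundingBox entries, then delegate to stack_samples."""
        cleaned = [
            {k: v for k, v in sample.items() if k not in ("crs", "bbox")}
            for sample in samples
        ]
        return stack_samples(cleaned)


    dataloader = DataLoader(dataset, sampler=sampler, collate_fn=collate_without_geo)

The built-in data modules do the equivalent by deleting those keys after the batch is assembled (the `geo.py` line linked above); stripping them in the collate function just moves that cleanup in front of Lightning's batch handling.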

adamjstewart commented 1 year ago

Is this still an issue or can this be closed?

lcoandrade commented 1 year ago

I've made a CustomGeoDataModule like this:

    class CustomGeoDataModule(GeoDataModule):
        def setup(self, stage: str) -> None:
            """Set up datasets.

            Args:
                stage: Either 'fit', 'validate', 'test', or 'predict'.
            """
            self.dataset = self.dataset_class(**self.kwargs)

            generator = torch.Generator().manual_seed(0)
            (
                self.train_dataset,
                self.val_dataset,
                self.test_dataset,
            ) = random_bbox_assignment(self.dataset, [0.6, 0.2, 0.2], generator)

            if stage in ["fit"]:
                self.train_batch_sampler = RandomBatchGeoSampler(
                    self.train_dataset, self.patch_size, self.batch_size, self.length
                )
            if stage in ["fit", "validate"]:
                self.val_sampler = GridGeoSampler(
                    self.val_dataset, self.patch_size, self.patch_size
                )
            if stage in ["test"]:
                self.test_sampler = GridGeoSampler(
                    self.test_dataset, self.patch_size, self.patch_size
                )

This solved my problem.
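
For reference, a hypothetical way to wire such a data module into the trainer from the reproduction steps (constructor argument names assumed from the `GeoDataModule` base class in 0.4.x, and `BATCH_SIZE` is a new constant here; for the NAIP & Chesapeake intersection, `setup` would have to build the combined dataset itself rather than rely on a single `dataset_class`):

    # Hypothetical usage sketch; constructor arguments assumed from the
    # GeoDataModule base class in torchgeo 0.4.x and may differ elsewhere.
    datamodule = CustomGeoDataModule(
        NAIP,                   # dataset_class; single dataset for illustration
        batch_size=BATCH_SIZE,  # assumed constant, not defined earlier
        patch_size=IMG_SIZE,
        length=SAMPLE_SIZE,
        num_workers=WORKERS,
        root=naip_root,         # forwarded to NAIP(**kwargs)
    )

    trainer.fit(model=task, datamodule=datamodule)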

trchudley commented 1 year ago

> My suggestion would be to write a simple data module (there are dozens of builtin examples) and use that instead of directly using a data loader. Maybe this is something we could add to our collation functions...

Hi @adamjstewart

I've also encountered this problem, and it took me a while to find the solution. Definitely +1 for adding this as a TorchGeo feature to make things as seamless as possible for end users working with GeoDatasets.

Cheers, Tom

adamjstewart commented 1 year ago

Reopening as a reminder to try to upstream some of our changes to PyTorch.