NKI-AI / ahcore

Ahcore is the AI for Oncology core computational pathology toolkit

Improve DlupDataModule #90

Open VanessaBotha opened 2 months ago

VanessaBotha commented 2 months ago

I have come across some things that could improve the code quality/readability of DlupDataModule:

1) self._already_called is never set. While constructing the dataloaders, we already check whether self._{stage}_data_iterator exists (i.e. is not None). That check is sufficient, so I don't see the added value of keeping self._already_called.

Suggestion: remove self._already_called
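
For reference, each dataloader already guards on its iterator before calling setup(); a minimal sketch of that pattern (method and attribute names assumed, mirroring the test_dataloader quoted in point 2):

    def train_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]:
        # The None-check below already ensures setup() only runs when the
        # iterator is missing, which is the only thing self._already_called
        # could record.
        if not self._fit_data_iterator:
            self.setup("fit")
        ...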

2) (bug) In test_dataloader, the validation data iterator is used instead of the test one:

    def test_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]:
        if not self._test_data_iterator:
            self.setup("test")
        batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size
        assert self._validate_data_iterator
        return self._construct_concatenated_dataloader(
            self._validate_data_iterator, batch_size=batch_size, stage="test"
        )

It should use self._test_data_iterator, not self._validate_data_iterator.

Suggestion: replace self._validate_data_iterator with self._test_data_iterator
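
With the fix applied, test_dataloader would read (a sketch; only the two iterator references change relative to the snippet above):

    def test_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]:
        if not self._test_data_iterator:
            self.setup("test")
        batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size
        # Use the test iterator here, not the validation one.
        assert self._test_data_iterator
        return self._construct_concatenated_dataloader(
            self._test_data_iterator, batch_size=batch_size, stage="test"
        )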

3) The data iterator in _construct_concatenated_dataloader is annotated as data_iterator: Iterator[_DlupDataset]. However, inside the method data_iterator is allowed to be None:

        if not data_iterator:
            return None

Suggestions: a) change the annotation to data_iterator: Iterator[_DlupDataset] | None, or b) raise a ValueError() instead of returning None when data_iterator is falsy.
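
A sketch of option (a), assuming the signature of _construct_concatenated_dataloader otherwise matches the call shown in point 2:

    def _construct_concatenated_dataloader(
        self,
        data_iterator: Iterator[_DlupDataset] | None,
        batch_size: int,
        stage: str,
    ) -> Optional[DataLoader[DlupDatasetSample]]:
        # The explicit "| None" makes the early return below consistent
        # with the annotation.
        if data_iterator is None:
            return None
        ...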

4) This is logged in _construct_concatenated_dataloader:

    lengths = np.asarray([len(ds) for ds in dataset.datasets])
    self._logger.info(
        f"Dataset for stage {stage} has {len(dataset)} samples and the following statistics:\n"
        f" - Mean: {lengths.mean():.2f}\n"
        f" - Std: {lengths.std():.2f}\n"
        f" - Min: {lengths.min():.2f}\n"
        f" - Max: {lengths.max():.2f}"
    )

Suggestion: add a log_stats method to the ConcatDataset class, e.g.:

    def log_stats(self) -> None:
        lengths = np.asarray([len(ds) for ds in self.datasets])
        logger.info(
            f"Total number of samples: {len(self)}\n"
            f" - Mean: {lengths.mean():.2f}\n"
            f" - Std: {lengths.std():.2f}\n"
            f" - Min: {lengths.min():.2f}\n"
            f" - Max: {lengths.max():.2f}"
        )

and in _construct_concatenated_dataloader:

    self._logger.info(f"Dataset for stage {stage} has the following statistics:")
    dataset.log_stats()

5) The variable stage is currently a string throughout the code. Suggestion: use an Enum to keep track of the stage. Especially within datasets_from_data_description (manifest.py), it would be nice to use the CategoryEnum from database_models.py.
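
A minimal sketch of what a stage enum could look like (member names are assumptions; reusing the existing CategoryEnum from database_models.py would be preferable to adding a new enum):

    from enum import Enum

    class Stage(str, Enum):
        # Hypothetical enum; inheriting from str keeps existing comparisons
        # such as stage == "fit" working during a gradual migration.
        FIT = "fit"
        VALIDATE = "validate"
        TEST = "test"
        PREDICT = "predict"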

6) self._limit_{stage}_samples is a protected attribute, but as far as I can see it is never set anywhere within DlupDataModule.

Suggestions: a) add it as an argument to the init function, e.g. limit_samples: dict[str, int], and then set e.g. self._limit_validate_samples = limit_samples.get("limit_validate_samples", None); or b) use kwargs in the init function, e.g. self._limit_validate_samples = kwargs.get("limit_validate_samples", None). A sketch of option (a) is shown below.
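
A sketch of option (a) (heavily reduced init; the dict keys follow the example above and are illustrative):

    from typing import Optional

    class DlupDataModule:  # sketch only; the real init takes many more arguments
        def __init__(self, limit_samples: Optional[dict[str, int]] = None) -> None:
            limit_samples = limit_samples or {}
            self._limit_fit_samples = limit_samples.get("limit_fit_samples", None)
            self._limit_validate_samples = limit_samples.get("limit_validate_samples", None)
            self._limit_test_samples = limit_samples.get("limit_test_samples", None)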

7) self._limit_{stage}_samples limits the number of datasets that are concatenated (i.e. the number of WSIs), not the actual number of samples (i.e. the number of tiles):

    data_iterator: Iterator[_DlupDataset]

        def construct_dataset() -> ConcatDataset:
            datasets = []
            for idx, ds in enumerate(data_iterator):
                datasets.append(ds)

                if limit_samples and idx >= limit_samples:
                    break

            return ConcatDataset(datasets=datasets)

Suggestion: rename self._limit_{stage}_samples to self._limit_{stage}_slides

8) self._limit_{stage}_samples can lead to problems when loading datasets from the cache: we check whether a cached dataset exists based on its UUID, but this UUID is generated from self.data_description only. So if you run the datamodule first with e.g. self._limit_fit_samples=10 and then run it a second time with self._limit_fit_samples=None, I expect it will load the incomplete dataset (with 10 slides) from the cache instead of constructing a new dataset.

Suggestions: a) instead of using only the UUID as the pkl filename, add e.g. a suffix _{limit} when the limit is not None; or b) add the limit somewhere in the path in _load_from_cache. A sketch of option (a) is shown below.
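
A sketch of option (a) (the helper and its arguments are hypothetical; it only illustrates making the cache path depend on the limit as well as the UUID):

    from pathlib import Path
    from typing import Optional

    def _cache_filename(cache_dir: Path, uuid: str, limit: Optional[int]) -> Path:
        # Including the limit in the filename means a limited run and a full
        # run can never resolve to the same cached dataset.
        suffix = f"_{limit}" if limit is not None else ""
        return cache_dir / f"{uuid}{suffix}.pkl"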