Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Add Hooks for Dataloader Beginning and End #18019

Open Queuecumber opened 12 months ago

Queuecumber commented 12 months ago

Description & Motivation

I have a task with multiple test datasets and I want to compute metrics individually per dataset. This requires taking a custom action when each dataset ends. It would be great to have built-in hooks in Lightning to enable this; it also fits the overall philosophy of the LightningModule processing hooks, which enable actions at specific points in the loop.

Pitch

I propose adding two new hooks for validation/test/prediction (training is probably not applicable here, but we can discuss if it is): on_X_dataloader_start and on_X_dataloader_end. They would be called as follows:

on_X_epoch_start
    on_X_dataloader_start ---------------|
        on_X_batch_start -|              |
            batch loop    |              | dataloader loop
        on_X_batch_end   -|              |
    on_X_dataloader_end------------------|
on_X_epoch_end

There would also need to be an update to allow metric logging in this manner, because currently Lightning raises an exception if the same key is logged "with different metadata" (I forget the exact text of the exception).
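To illustrate the intended usage, here is a minimal sketch of how a LightningModule could consume the proposed hook. The hook name follows the pitch above and does not exist in Lightning today; the model, metric, and log key are placeholders for the example.

import torch
from lightning.pytorch import LightningModule
from torchmetrics import MeanSquaredError


class MultiTestSetModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(8, 1)
        self.metric = MeanSquaredError()

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        self.metric.update(self.model(x), y)

    # Proposed hook: would run once after the last batch of each test dataloader.
    def on_test_dataloader_end(self, dataloader_idx: int) -> None:
        self.log(f"test/mse/dataloader_{dataloader_idx}", self.metric.compute())
        self.metric.reset()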

Alternatives

I currently provide a subset of this functionality (the *_dataloader_end hooks) using the following callback:

from typing import Any

from lightning.pytorch import Callback, LightningModule, Trainer

class MultiloaderNotifier(Callback):
    def __init__(self) -> None:
        # Last seen dataloader index per stage.
        self.dataloader_idxs = {"validation": 0, "test": 0, "predict": 0}

    def on_batch_start(self, stage: str, pl_module: LightningModule, dataloader_idx: int) -> None:
        # The dataloader index changed, so the previous dataloader has finished:
        # notify the module if it defines an `on_<stage>_dataloader_end` hook.
        if dataloader_idx != self.dataloader_idxs[stage]:
            if hook := getattr(pl_module, f"on_{stage}_dataloader_end", None):
                hook(self.dataloader_idxs[stage])
        self.dataloader_idxs[stage] = dataloader_idx

    def on_epoch_end(self, stage: str, pl_module: LightningModule) -> None:
        # The epoch is over, so the last dataloader has also finished.
        if hook := getattr(pl_module, f"on_{stage}_dataloader_end", None):
            hook(self.dataloader_idxs[stage])
        self.dataloader_idxs[stage] = 0

    def on_validation_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.on_batch_start("validation", pl_module, dataloader_idx)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_epoch_end("validation", pl_module)

    def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.on_batch_start("test", pl_module, dataloader_idx)

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_epoch_end("test", pl_module)

    def on_predict_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.on_batch_start("predict", pl_module, dataloader_idx)

    def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_epoch_end("predict", pl_module)

which allows me to define the following hook on my LightningModule:

def on_test_dataloader_end(self, dataloader_idx: int) -> None:
    self.logger.log_metric(self.metric.compute())
    self.metric.reset()

However, this has a number of drawbacks. First, the hooks are not necessarily called in the right order, since I need to hook batch_start from the callback in order to detect the change in dataloader (so batch_start -> dataloader_end is not conceptually correct). Next, I need to hook epoch_end to make the last dataloader_end call (epoch_end -> dataloader_end is also not conceptually correct). Lastly, there is no obvious way to provide both dataloader_start and dataloader_end; I don't need both, but someone else might.

There is also the issue of the logging itself. Currently, to avoid the exception, I need to bypass the LightningModule log function and call log_metric on the logger.
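One way to write that bypass against the generic Logger.log_metrics interface (a sketch, with a made-up metric key) looks like this:

def on_test_dataloader_end(self, dataloader_idx: int) -> None:
    # Write straight to the logger backend to sidestep the self.log() restriction.
    self.logger.log_metrics(
        {f"test/metric/dataloader_{dataloader_idx}": self.metric.compute().item()},
        step=self.global_step,
    )
    self.metric.reset()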

Additional context

No response

cc @borda @justusschock @awaelchli @carmocca @Blaizzy

awaelchli commented 10 months ago

I believe this is also loosely related to #18282. We should look into how these limitations can be addressed so that it is easier to work with dataloader boundaries when using multiple dataloaders.

Queuecumber commented 10 months ago

Yeah, it's related, but I think there are some good reasons for not resetting metrics between dataloaders, or at least for making it configurable; it depends on the use case.

guillaume-rochette-oxb commented 2 months ago

Hi (sorry for the very late reply :/), I do not think there's a very good reason for not resetting in between several dataloaders, as we can create a ConcatDataset instance from multiple Dataset instances. Unless there's a use case I've missed?
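To spell out that distinction (dataset_a, dataset_b, model, and trainer are placeholders): if you want one aggregate metric, you can concatenate the datasets into a single dataloader; separate dataloaders are only needed for the per-dataset case.

from torch.utils.data import ConcatDataset, DataLoader

# One aggregate metric over all test data: concatenate the datasets up front.
combined = DataLoader(ConcatDataset([dataset_a, dataset_b]), batch_size=32)
trainer.test(model, dataloaders=combined)

# Per-dataset metrics: keep separate dataloaders, which is where the
# reset-at-the-boundary question comes in.
trainer.test(
    model,
    dataloaders=[DataLoader(dataset_a, batch_size=32), DataLoader(dataset_b, batch_size=32)],
)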

guillaume-rochette-oxb commented 1 month ago

Any update? @awaelchli @carmocca

guillaume-rochette-oxb commented 1 month ago

I've done a little bit of digging, and basically the issue happens in _EvaluationLoop.run(). My comments are marked with START and END.

    @_no_grad_context
    def run(self) -> List[_OUT_DICT]:
        self.setup_data()
        if self.skip:
            return []
        self.reset()
        self.on_run_start()
        data_fetcher = self._data_fetcher
        assert data_fetcher is not None
        previous_dataloader_idx = 0
        while True:
            try:
                if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
                    dataloader_iter = next(data_fetcher)
                    # hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting
                    batch = data_fetcher._batch
                    batch_idx = data_fetcher._batch_idx
                    dataloader_idx = data_fetcher._dataloader_idx
                else:
                    dataloader_iter = None
                    batch, batch_idx, dataloader_idx = next(data_fetcher)
                if previous_dataloader_idx != dataloader_idx:
                    # the dataloader has changed, notify the logger connector
                    # START
                    self._store_dataloader_outputs()
                    # END
                    # For all the dataloaders but the last one, the metrics for which we have a
                    # reference get computed, but they do not get reset.
                previous_dataloader_idx = dataloader_idx
                self.batch_progress.is_last_batch = data_fetcher.done
                # run step hooks
                self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
            except StopIteration:
                # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
                break
            finally:
                self._restarting = False
        # START
        self._store_dataloader_outputs()
        return self.on_run_end()
        # END
        # For the last dataloader, the metrics are computed and then finally reset in `self.on_run_end()`.

I think this does not make much sense, even in the case where we have several metric objects, because retaining the metric states after they have been computed just holds on to resources that would otherwise be freed. I took the example of MeanAveragePrecision in #18282 because, if we were to compute it on instance masks, they all have to be retained on the CPU, which could be very expensive if the number of dataloaders and samples grows large.
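To make the resource argument concrete, here is a toy metric (a stand-in for MeanAveragePrecision, not its actual internals) showing that list-like torchmetrics state accumulated by update() survives compute() and is only freed by reset():

import torch
from torchmetrics import Metric


class MaskCollector(Metric):
    """Toy stand-in for a mask-based metric: keeps every mask it sees as state."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("masks", default=[], dist_reduce_fx="cat")

    def update(self, masks: torch.Tensor) -> None:
        self.masks.append(masks)

    def compute(self) -> torch.Tensor:
        return torch.cat(self.masks).float().mean()


metric = MaskCollector()
for _ in range(100):                                        # e.g. batches from dataloader 0
    metric.update(torch.zeros(8, 512, 512, dtype=torch.bool))

metric.compute()               # the result is computed here...
print(len(metric.masks))       # ...but all 100 batches of masks are still held in memory
metric.reset()                 # only reset() frees that state
print(len(metric.masks))       # 0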