Queuecumber opened 12 months ago
I believe this is also loosely related to #18282. We should look into how these limitations can be addressed so it is easier to work with dataloader boundaries when using multiple dataloaders.
Yeah, it's related, but I think there are some good reasons for not resetting metrics between dataloaders, or at least for making it configurable; it depends on the use case.
Hi,
(Sorry for the very late reply :/)
I do not think there is a very good reason for not resetting between dataloaders, since anyone who wants combined metrics can create a `ConcatDataset` instance from multiple `Dataset` instances. Unless there's a use case I've missed?
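For reference, `torch.utils.data.ConcatDataset` just chains map-style datasets end to end, which is why it can substitute for "no reset between dataloaders". A dependency-free sketch of that mechanism (not the actual PyTorch implementation) looks like this:

```python
class ConcatDataset:
    """Minimal sketch of torch.utils.data.ConcatDataset: chains several
    map-style datasets (anything with __len__/__getitem__) into one."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # cumulative sizes map a global index to (dataset, local index)
        self.cumulative = []
        total = 0
        for d in self.datasets:
            total += len(d)
            self.cumulative.append(total)

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, idx):
        for ds_idx, bound in enumerate(self.cumulative):
            if idx < bound:
                start = self.cumulative[ds_idx - 1] if ds_idx else 0
                return self.datasets[ds_idx][idx - start]
        raise IndexError(idx)
```

With a single concatenated dataset there is only one dataloader, so metrics naturally accumulate across all the underlying datasets without any special-casing in the loop.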
Any update? @awaelchli @carmocca
I've done a little bit of digging, and basically the issue happens in `_EvaluationLoop.run()`. Personal comments are indicated with START and END.

```python
@_no_grad_context
def run(self) -> List[_OUT_DICT]:
    self.setup_data()
    if self.skip:
        return []
    self.reset()
    self.on_run_start()
    data_fetcher = self._data_fetcher
    assert data_fetcher is not None
    previous_dataloader_idx = 0
    while True:
        try:
            if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
                dataloader_iter = next(data_fetcher)
                # hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting
                batch = data_fetcher._batch
                batch_idx = data_fetcher._batch_idx
                dataloader_idx = data_fetcher._dataloader_idx
            else:
                dataloader_iter = None
                batch, batch_idx, dataloader_idx = next(data_fetcher)
            if previous_dataloader_idx != dataloader_idx:
                # the dataloader has changed, notify the logger connector
                # START
                self._store_dataloader_outputs()
                # END
                # For every dataloader except the last one, the metrics we hold
                # a reference to get computed here, but they never get reset.
            previous_dataloader_idx = dataloader_idx
            self.batch_progress.is_last_batch = data_fetcher.done
            # run step hooks
            self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
        except StopIteration:
            # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
            break
        finally:
            self._restarting = False
    # START
    self._store_dataloader_outputs()
    return self.on_run_end()
    # END
    # For the last dataloader, the metrics are computed and then finally reset
    # in `self.on_run_end()`.
```
I think this does not make much sense, even in the event where we have several metric objects: retaining the metric states after they have been computed just holds on to resources that would otherwise be freed. I took the example of `MeanAveragePrecision` in #18282 because, if we compute it on instance masks, they all have to be retained on the CPU, which can become very expensive as the number of dataloaders and samples grows.
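To make the resource argument concrete, here is a toy torchmetrics-style metric (the `update`/`compute`/`reset` names mirror the torchmetrics API, but the class itself is a hypothetical simplified mean, not `MeanAveragePrecision`): state grows on every `update`, and only `reset` frees it.

```python
class ToyMeanMetric:
    """Hypothetical torchmetrics-style metric that accumulates state in memory."""

    def __init__(self):
        self.preds = []

    def update(self, batch_preds):
        # State grows with every batch. For mask-level metrics, each entry
        # can be a large tensor that must be retained on the CPU.
        self.preds.extend(batch_preds)

    def compute(self):
        return sum(self.preds) / len(self.preds)

    def reset(self):
        self.preds.clear()  # frees the accumulated state


metric = ToyMeanMetric()
results = []
for dataloader in ([1.0, 2.0], [3.0, 5.0]):  # two stand-in dataloaders
    for pred in dataloader:
        metric.update([pred])
    results.append(metric.compute())
    # Without this reset, the second result would mix in the first
    # dataloader's samples (2.75 instead of 4.0), and the first
    # dataloader's state would stay resident until the very end of the run.
    metric.reset()
```

This is exactly the boundary at which `_store_dataloader_outputs()` computes the metrics: calling `reset()` there as well would release the per-dataloader state immediately.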
Description & Motivation
I have a task with multiple test datasets, and I want to compute metrics individually per dataset. This requires me to take a custom action when each dataset ends. It would be great to have built-in hooks in Lightning to enable this; it also fits the overall philosophy of the LightningModule processing hooks, which enable actions at specific points in the loop.
Pitch
I propose adding two new hooks for validation/test/prediction (training is probably not applicable here, but we can discuss if it is): `on_X_dataloader_start` and `on_X_dataloader_end`, where `X` is the stage name. They would be called as follows:
```
on_X_epoch_start
    on_X_dataloader_start ---------------|
        on_X_batch_start -|              |
        batch loop        |              | dataloader loop
        on_X_batch_end   -|              |
    on_X_dataloader_end -----------------|
on_X_epoch_end
```
There would also need to be an update to allow metric logging in this manner, because currently Lightning throws an exception if the same key is logged "with different metadata" (I forget the exact text of the exception).
Alternatives
I currently provide a subset of this functionality (the `*_dataloader_end` hooks) using a custom callback, which allows me to hook in myself.
However, this has a number of drawbacks. First, the hooks are not necessarily called in the right order, since I need to hook `batch_start` from the callback in order to detect the change in dataloader (so `batch_start` -> `dataloader_end` is not conceptually correct). Next, I need to hook `epoch_end` to make the last `dataloader_end` call (`epoch_end` -> `dataloader_end` is also not conceptually correct). Lastly, there is no obvious way to provide both `dataloader_start` and `dataloader_end`; I don't need them both, but someone else might.
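The workaround described above can be sketched roughly as follows. The hook signatures mirror the `lightning.pytorch.Callback` API; the class is shown standalone (not subclassing `Callback`) so the bookkeeping is visible, and `on_dataloader_end` is a hypothetical user-supplied function, not a real Lightning hook:

```python
class DataloaderEndCallback:
    """Sketch: synthesizes a `dataloader_end` event by watching
    dataloader_idx change in on_validation_batch_start, and flushes the
    last dataloader from on_validation_epoch_end."""

    def __init__(self, on_dataloader_end):
        self._on_dataloader_end = on_dataloader_end  # user-supplied action
        self._current_idx = None

    def on_validation_batch_start(self, trainer, pl_module, batch,
                                  batch_idx, dataloader_idx=0):
        if self._current_idx is not None and dataloader_idx != self._current_idx:
            # the index changed, so the previous dataloader just finished
            self._on_dataloader_end(pl_module, self._current_idx)
        self._current_idx = dataloader_idx

    def on_validation_epoch_end(self, trainer, pl_module):
        if self._current_idx is not None:
            # the last dataloader never triggers an index change; flush it here
            self._on_dataloader_end(pl_module, self._current_idx)
            self._current_idx = None
```

The sketch makes the ordering problem visible: `dataloader_end` for dataloader N only fires from inside `batch_start` of dataloader N+1 (or from `epoch_end` for the final one), which is exactly the conceptual mismatch described above.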
There is also the issue of the logging itself. Currently, to avoid the exception, I need to bypass the LightningModule `log` function and call `log_metric` on the logger directly.
Additional context
No response
cc @borda @justusschock @awaelchli @carmocca @Blaizzy