tchaton opened this issue 2 years ago
```python
def test_multiple_dataloaders_logging(tmpdir):
    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx, dataloader_idx):
            self.log("value_1", dataloader_idx, add_dataloader_idx=False)
```
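For context, a fuller sketch of how this snippet could be exercised with two validation dataloaders. The `Trainer` setup and the two `RandomDataset` loaders are my own assumption for illustration; `BoringModel`/`RandomDataset` come from Lightning's boring-classes helpers, whose exact import path depends on the Lightning version:

```python
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset  # import path varies by version

class TestModel(BoringModel):
    def validation_step(self, batch, batch_idx, dataloader_idx):
        # same key logged from both dataloaders, without the dataloader index suffix
        self.log("value_1", dataloader_idx, add_dataloader_idx=False)

    def val_dataloader(self):
        # two validation dataloaders, so "value_1" is logged from both
        return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

trainer = Trainer(fast_dev_run=True)
trainer.fit(TestModel())
```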
Isn't this incorrect behavior? We have a single ResultCollection instance handling all the keys, but `value_1` here is common to both dataloaders, so it would be difficult to know which value to monitor, e.g. for ReduceLROnPlateau/ModelCheckpoint/EarlyStopping.
Hey @rohitgr7.
What do you think the default behavior should be here? I'm not sure I understand your answer.
@tchaton I believe `add_dataloader_idx` is meant to distinguish the metrics stored internally by appending the dataloader index to the name. The flow with multiple dataloaders is:

1. complete the val loop for dataloader 1
2. complete the val loop for dataloader 2
3. ...
4. call the checkpoint callback with a monitor value

Now, if a metric is logged under the same key for both dataloaders with `add_dataloader_idx=False`, how do we figure out which dataloader's metric to use when monitoring in the checkpoint callback?
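To make the monitoring question concrete, a small sketch of my understanding of the key naming; the `"/dataloader_idx_<i>"` suffix format is my assumption about how Lightning disambiguates the keys:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# With add_dataloader_idx=True (the default), each dataloader gets its own entry,
# e.g. "value_1/dataloader_idx_0" and "value_1/dataloader_idx_1", so the monitor is unambiguous:
ckpt_per_dataloader = ModelCheckpoint(monitor="value_1/dataloader_idx_0", mode="max")

# With add_dataloader_idx=False and the same key for both dataloaders, only "value_1" exists
# in trainer.callback_metrics, and the callback cannot tell which dataloader it came from:
ckpt_shared_key = ModelCheckpoint(monitor="value_1", mode="max")
```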
(Leaving my thoughts here in writing after discussing this online.)

The cross-reduction does make sense considering what's supposed to happen with `add_dataloader_idx=False`:

> add_dataloader_idx: if ``True``, appends the index of the current dataloader to
> the name (when using multiple dataloaders). If False, user needs to give unique names for
> each dataloader to not mix the values.

This means the expectation is that the values are mixed in the `add_dataloader_idx=False` case.
One problem with this is that it contradicts the assertions written in https://github.com/PyTorchLightning/pytorch-lightning/pull/10810, where the added test separates the logged values by their respective dataloader indices when the key matches.
In total, we have 3 options:
So I'd go with either (1), the easy choice, or (3), the complex choice.
One more thing to note is that the user can always manually post-process the results of (1) to obtain (3), whereas it's not possible the other way around.
@tchaton I can work on this when we choose one of the options as I've got a draft implementation (mostly) working.
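As a rough illustration of that manual post-processing from (1) to (3): this is only a sketch, the `_compute_value` helper and the key names are illustrative, and it collects the values itself instead of relying on the logged results:

```python
from collections import defaultdict

import torch
from pytorch_lightning import LightningModule

class ManualCrossReduction(LightningModule):
    def __init__(self):
        super().__init__()
        self._val_values = defaultdict(list)  # dataloader_idx -> values seen this epoch

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        value = self._compute_value(batch)  # hypothetical per-batch metric, assumed to be a scalar tensor
        # option (1): per-dataloader curves, e.g. "value_1/dataloader_idx_0"
        self.log("value_1", value)
        self._val_values[dataloader_idx].append(value.detach())

    def on_validation_epoch_end(self):
        # recover option (3) by hand: reduce across all dataloaders
        all_values = [v for values in self._val_values.values() for v in values]
        if all_values:
            self.log("value_1_all", torch.stack(all_values).mean())
        self._val_values.clear()
```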
> how do we figure out which dataloader's metric to use when monitoring in the checkpoint callback?
This is an inherent limitation of the design: `trainer.callback_metrics` is a flat dictionary, so collisions will overwrite values. But that's a problem for all 3 options described in my previous comment.
IMO, it's a good thing to support cross-reduction across dataloaders, but I'd argue from the user's point of view that what we have on master is good enough right now, and I would expect Lightning to raise an error telling me I made the mistake of using the same key for both of my dataloaders.
If we support cross-reduction, then using the same key for multiple dataloaders is not an error but a feature, as it would be the mechanism to trigger it.
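A toy illustration of that flat-dictionary collision (plain Python, not Lightning's actual implementation):

```python
# trainer.callback_metrics behaves like a flat dict keyed by the logged name, so a shared
# key with add_dataloader_idx=False means the last dataloader's value overwrites the others.
callback_metrics = {}
for dataloader_idx in range(2):
    callback_metrics["value_1"] = float(dataloader_idx)  # same key, last write wins

print(callback_metrics)  # {'value_1': 1.0}: dataloader 0's value is gone
```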
So it would just replace the value under the common key with the latest one? With, say, 10+ dataloaders, users might not see correct graphs in their loggers and only realize later that, because they used the same key, the metrics for the other dataloaders are gone?
If the 10 dataloaders use the same key and we support (3), the values logged under that key would be reduced together across all of them, regardless of the `dataloader_idx`.
Hey @carmocca @rohitgr7. I would personally support @carmocca's option 3, fully implement cross-reduction, with the following behavior:

```python
def ...:
    # These don't reduce metrics across dataloaders; add_dataloader_idx only controls
    # whether the dataloader_idx suffix is appended for curve visualization.
    self.log("key", value, add_dataloader_idx=True)
    self.log("key", value, add_dataloader_idx=False)  # perform compute after each dataloader loop

    # Cross-dataloader reduction (cross_dataloader_reduction would be a new argument):
    self.log("key", value, add_dataloader_idx=True, cross_dataloader_reduction=True)  # raise a MisconfigurationException
    self.log("key", value, cross_dataloader_reduction=True)  # perform compute once all dataloader loops are completed
```
I am currently having this issue, trying to log metrics for every dataloader individually while also logging the same metric over all dataloaders combined. As mentioned above, I fully expected this to just work after reading:

> add_dataloader_idx: If True, appends the index of the current dataloader to the name
> (when using multiple dataloaders). If False, user needs to give unique names for each
> dataloader to not mix the values.
Any news on when this will be supported?
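For concreteness, roughly the pattern being described. This is only a sketch, not a complete module; the keys and the `_batch_accuracy` helper are illustrative, under the reading that a shared key with `add_dataloader_idx=False` combines the dataloaders:

```python
from pytorch_lightning import LightningModule

class PerDataloaderAndCombined(LightningModule):
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        acc = self._batch_accuracy(batch)  # hypothetical helper returning a scalar tensor
        # per-dataloader curve: the key gets a dataloader_idx suffix for each dataloader
        self.log("val/acc", acc)
        # combined curve over all dataloaders: the same key logged from every dataloader
        self.log("val/acc_all", acc, add_dataloader_idx=False)
```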
I just came across this issue while developing the following scenario.
I've made a `LightningDataModule` that produces various validation and testing `DataLoader`s. Conceptually, each `DataLoader` is associated with its own (identical) set of metrics (right now all `torchmetrics`), and I'd also like to have an aggregate set of these metrics over all validation and testing `DataLoader`s.
This relationship is implemented at the `LightningModule` level, which in my opinion is not great, but it doesn't matter for this issue.
The problem arises when I try to use `LightningModule.log("validation/metric", self.metric_obj, add_dataloader_idx=False)` to log the aggregate metric in `validation_step` (same for `test_step`).
My current workaround is updating the validation/test metrics at `(phase)_step` time and then logging them in the `on_(phase)_epoch_end` hooks. This is a compromise that works for me, as I do not need step-level plots for validation and testing, but others might not be so lucky.
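Roughly, the workaround looks like this. It is only a sketch: the `_shared_step` helper and the key names are illustrative, and `BinaryAccuracy` is just a stand-in for whatever `torchmetrics` metric is being aggregated:

```python
import torchmetrics
from pytorch_lightning import LightningModule

class AggregateOverDataloaders(LightningModule):
    def __init__(self):
        super().__init__()
        # one aggregate metric shared by every validation dataloader
        self.val_acc_all = torchmetrics.classification.BinaryAccuracy()

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        preds, target = self._shared_step(batch)  # hypothetical helper
        # only update here; logging the shared metric object at step time with
        # add_dataloader_idx=False is exactly what this issue is about
        self.val_acc_all.update(preds, target)

    def on_validation_epoch_end(self):
        # log the aggregate once, after every validation dataloader has been consumed
        self.log("validation/acc_all", self.val_acc_all.compute())
        self.val_acc_all.reset()
```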
I'd also like to do the same for training, but after reading these two issues, https://github.com/Lightning-AI/lightning/issues/9352 and https://github.com/Lightning-AI/lightning/issues/15688, I just decided to concatenate the training `Dataset`s and ignore the per-dataset training metrics breakdown.
Any updates on this issue?
🐛 Bug

To Reproduce

The current behavior with `add_dataloader_idx` seems quite confusing to me. As a user, I don't know whether I should expect the value to be reduced across all dataloaders and added to both results objects.

Expected behavior
Additional context
cc @tchaton