Problem
When using `MultiTaskDataLoader` with more than one task and the `instances_per_epoch` feature, the number of batches in the epoch is overestimated, showing T*B instead of B, where T is the number of tasks and B is the (actual) number of batches for that epoch. E.g., I see this output with T=2:

Steps to Reproduce
Configure an environment that uses `MultiTaskDataLoader` with more than one task and with `instances_per_epoch` set to some integer.

Cause
The branch of `MultiTaskDataLoader.__len__` that is called when `instances_per_epoch is not None` assumes that each dataset will have `self._instances_per_epoch` instances for the epoch, estimating a total of `num_tasks * self._instances_per_epoch`. However, the implementation of `MultiTaskDataLoader._get_instances_for_epoch` guarantees that the instances across all tasks will sum to approximately `self._instances_per_epoch`.

Suggested Solution
Modify `MultiTaskDataLoader.__len__` to use the same logic as `_get_instances_for_epoch` when computing the number of batches. I'm happy to personally open a PR for this.
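
To make the mismatch concrete, here is a minimal standalone sketch of the arithmetic. It is not the AllenNLP implementation: the function names (`reported_length`, `split_instances`, `corrected_length`) are hypothetical, and it assumes that batches are filled from the pooled instances of all tasks and that the per-task share is the floor of each task's proportion of `instances_per_epoch`. A real fix would need to go through whatever scheduler and sampler are configured.

```python
import math
from typing import Dict


def count_batches(num_instances: int, batch_size: int) -> int:
    """Batches needed for num_instances; the last batch may be partial."""
    return math.ceil(num_instances / batch_size)


def reported_length(num_tasks: int, instances_per_epoch: int, batch_size: int) -> int:
    # Mirrors the assumption described under "Cause": every task is
    # counted as contributing a full instances_per_epoch.
    return count_batches(num_tasks * instances_per_epoch, batch_size)


def split_instances(proportions: Dict[str, float], instances_per_epoch: int) -> Dict[str, int]:
    # Assumed split: each task gets its proportional share of the epoch
    # budget, rounded down, so the shares sum to about instances_per_epoch.
    total = sum(proportions.values())
    return {
        task: math.floor(p * instances_per_epoch / total)
        for task, p in proportions.items()
    }


def corrected_length(proportions: Dict[str, float], instances_per_epoch: int, batch_size: int) -> int:
    # Count batches from the per-task shares rather than from
    # num_tasks * instances_per_epoch.
    per_task = split_instances(proportions, instances_per_epoch)
    return count_batches(sum(per_task.values()), batch_size)


if __name__ == "__main__":
    proportions = {"task_a": 1.0, "task_b": 1.0}  # hypothetical tasks, T=2
    print(reported_length(len(proportions), 1000, 10))  # 200, i.e. T*B
    print(corrected_length(proportions, 1000, 10))      # 100, i.e. B
```

With two equally weighted tasks, `instances_per_epoch=1000`, and a batch size of 10, the first count is twice the second, which matches the T*B versus B discrepancy described above.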