open-mmlab / mmengine

OpenMMLab Foundational Library for Training Deep Learning Models
https://mmengine.readthedocs.io/
Apache License 2.0

[Feature] Speed up the resume process of IterBased loop #1520

Open · YinAoXiong opened this issue 8 months ago

YinAoXiong commented 8 months ago

What is the feature?

https://github.com/open-mmlab/mmengine/blob/2c4516c62294964065d058d98799402f50afdef6/mmengine/runner/loops.py#L281 The current resume logic iterates the dataloader for n steps. When n is large this is very slow, because the real data loading and preprocessing is executed for every skipped step. Is there a good way to iterate only the indices without running the actual data-loading pipeline?
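For context, the linked resume logic essentially drains the dataloader iterator one step at a time, roughly like this simplified sketch:

```python
# Simplified sketch of the skipping loop in mmengine/runner/loops.py:
# every skipped step still runs the full loading/augmentation pipeline,
# which is why resuming with a large iteration count is slow.
for _ in range(self._iter):
    next(self.dataloader_iterator)
```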

  1. One possible approach is to agree with users on a dataset interface that can return fake data, and have it return fake data while resuming:

    
    class Dataset:
    
        def __init__(self):
            self._skip_flag = False
    
        def __getitem__(self, index):
            if self._skip_flag:
                return fake_data  # cheap placeholder, no real loading
            # load and process the real sample
            return real_data
    
        def skip(self):
            self._skip_flag = True
    
        def resume(self):
            self._skip_flag = False

The corresponding handling logic in the loop:

        # Skip cheaply if the dataset implements the skip/resume protocol;
        # otherwise fall back to the slow full iteration.
        if (
            hasattr(self.dataloader.dataset, "skip")
            and callable(self.dataloader.dataset.skip)
            and hasattr(self.dataloader.dataset, "resume")
            and callable(self.dataloader.dataset.resume)
        ):
            self.dataloader.dataset.skip()
            for _ in range(self._iter):
                next(self.dataloader_iterator)
            self.dataloader.dataset.resume()
        else:
            for _ in range(self._iter):
                next(self.dataloader_iterator)
2. Approach 1 still requires cooperation from the user. Could the dataloader itself be manipulated so that fast skipping is transparent to users?
```python
iter_batch_sampler = iter(self.dataloader.batch_sampler)
for _ in range(self._iter):
    next(iter_batch_sampler)
```

Iterating the batch_sampler directly works when num_workers=0, but with multiple workers the resumed data order comes out wrong (presumably because a multi-worker DataLoader dispatches index batches to its worker processes ahead of time, so a separately created batch_sampler iterator does not stay in sync with them). I would like to know whether there is a good solution.

Any other context?

https://discuss.pytorch.org/t/is-there-any-way-to-skip-steps-in-a-dataloader/123201
https://pytorch.org/data/main/dataloader2.html

Snapshot the state of data-preprocessing pipeline (WIP)

zhouzaida commented 6 months ago

A minimal-change solution is to mock the dataset's `__getitem__` method before iterating:

    def run(self) -> None:
        """Launch training."""
        self.runner.call_hook('before_train')
        # In iteration-based training loop, we treat the whole training
        # process as a big epoch and execute the corresponding hook.
        self.runner.call_hook('before_train_epoch')
        if self._iter > 0:
            print_log(
                f'Advance dataloader {self._iter} steps to skip data '
                'that has already been trained',
                logger='current',
                level=logging.WARNING)
            # Temporarily mock __getitem__ so the skipped steps return
            # cheap fake data instead of running the real pipeline.
            # `a_new_getitem_method` is a placeholder for such a stub.
            old_getitem = self.dataloader_iterator.dataset.__getitem__
            self.dataloader_iterator.dataset.__getitem__ = a_new_getitem_method
            for _ in range(self._iter):
                next(self.dataloader_iterator)
            self.dataloader_iterator.dataset.__getitem__ = old_getitem
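
For reference, a self-contained sketch of this mocking idea (`fast_skip` is a hypothetical helper, not mmengine code). Since Python resolves `dataset[index]` through `type(dataset).__getitem__`, patching the class rather than the instance is the safer variant; and with `num_workers > 0` the worker processes re-import the dataset class unpatched, so this only reliably applies to in-process loading:

```python
import contextlib


@contextlib.contextmanager
def fast_skip(dataset_cls, fake_sample):
    """Temporarily make ``dataset_cls.__getitem__`` return cheap fake data."""
    old_getitem = dataset_cls.__getitem__
    # The stub must still be a valid sample for the collate function.
    dataset_cls.__getitem__ = lambda self, index: fake_sample
    try:
        yield
    finally:
        dataset_cls.__getitem__ = old_getitem


# Hypothetical usage inside the loop:
# with fast_skip(type(self.dataloader.dataset), fake_sample=0):
#     for _ in range(self._iter):
#         next(self.dataloader_iterator)
```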
chtzs commented 6 months ago

I believe this PR is the cause of the issue: https://github.com/open-mmlab/mmengine/pull/1471. While it fixed the resume iteration problem, it also made resuming slow. A suitable solution would be to call the `_next_index()` method of the DataLoader's built-in iterator to skip a batch without reading the data.
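
A minimal sketch of the `_next_index()` idea (it relies on the private `_BaseDataLoaderIter` API, which may change across PyTorch versions, and is shown with `num_workers=0` because a multi-worker iterator prefetches index batches as soon as it is created):

```python
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Toy map-style dataset used only to demonstrate index skipping."""

    def __len__(self):
        return 100

    def __getitem__(self, index):
        return index


loader = DataLoader(ToyDataset(), batch_size=4, num_workers=0)
it = iter(loader)

# Advance the sampler state without loading any data.
for _ in range(5):
    it._next_index()  # private API: yields the next batch of indices

print(next(it))  # tensor([20, 21, 22, 23]) -- resumes at the 6th batch
```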