Open · zhengwy888 opened this issue 1 year ago
🐛 Describe the bug

I tried all of the approaches below to resume a shuffled pipeline from a DataLoader2 checkpoint. Only the last one returned the correct remaining elements, and it is far too hacky. Is there a better way?
```python
# Import paths are the ones I used with torch 2.0 / torchdata 0.7.0.dev; adjust if they have moved.
from torch.utils.data.datapipes._snapshot import _simple_graph_snapshot_restoration
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
from torchdata.dataloader2 import DataLoader2, InProcessReadingService
from torchdata.dataloader2.graph._serialization import deserialize_datapipe, serialize_datapipe
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(list(range(20)))
dp = dp.shuffle()
items = []
rs = InProcessReadingService()
dl = DataLoader2(dp, reading_service=rs)
iter1 = iter(dl)
for _ in range(4):
    next(iter1)
# 16 elements left in dl
state = dl.state_dict()

# Example 1: restore directly from the state dict.
dl2 = DataLoader2.from_state(state, reading_service=rs)
# assert len(list(dl2)) == 20 - 4  # got 20

# Example 2: round-trip the datapipe through (de)serialization.
dp2 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
# assert len(list(dp2)) == 20 - 4  # got 20

# Example 3: deserialize, then fast-forward the graph.
dp3 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
_simple_graph_snapshot_restoration(dp3, dp3._number_of_samples_yielded)
ret3 = list(dp3)
assert len(ret3) == 20 - 4  # but the contents differ

# Example 4: restore from state, then fast-forward.
dl4 = DataLoader2.from_state(state, reading_service=rs)
_simple_graph_snapshot_restoration(dl4.datapipe, dl.datapipe._number_of_samples_yielded)
ret4 = list(dl4)
assert len(ret4) == 20 - 4  # but the contents differ

# Example 5: deserialize, then manually restore the Shuffler's buffer and RNG state.
dp5 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
pipes = get_all_pipes(dp5)  # local helper that returns every DataPipe in the graph
for pipe in pipes:
    if isinstance(pipe, ShufflerIterDataPipe):
        buffer_cache = pipe._buffer[:]
        assert len(buffer_cache) == 20 - 4
        rng_state = pipe._rng.getstate()
_simple_graph_snapshot_restoration(dp5, dl.datapipe._number_of_samples_yielded)
dp5._buffer = buffer_cache[:]
dp5._rng.setstate(rng_state)
it5 = iter(dp5)
ret5 = list(it5)
assert len(ret5) == 20 - 4

expected = list(iter1)
# ret5 is the only method that worked
# assert ret3 == expected
# assert ret4 == expected
assert ret5 == expected
```
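For completeness, `get_all_pipes` above is just a small local helper, not a torchdata API. A minimal sketch of it, assuming the `{id: (datapipe, sub_graph)}` layout returned by torch 2.0's `traverse_dps`, could look like this:

```python
# Hypothetical helper (not part of torchdata): flatten the DataPipe graph into a list.
from torch.utils.data.graph import traverse_dps

def get_all_pipes(datapipe):
    pipes = []

    def _walk(graph):
        # traverse_dps returns a dict of {id: (datapipe, sub_graph)}; recurse into sub-graphs.
        for dp, sub_graph in graph.values():
            pipes.append(dp)
            _walk(sub_graph)

    _walk(traverse_dps(datapipe))
    return pipes
```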
Versions

```text
PyTorch version: 2.0.0a0+gite9ebda2
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 12.0.1 (https://github.com/conda-forge/clangdev-feedstock d44358f44aef33e9fa7c5f93e2481ee8f1a04ab6)
CMake version: version 3.19.1
Libc version: glibc-2.31

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10) [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-64-generic-x86_64-with-glibc2.10
Is CUDA available: False
CUDA runtime version: 12.0.140
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] mypy-protobuf==3.3.0
[pip3] numpy==1.23.5
[pip3] pytorch3d==0.6.2
[pip3] torch==2.0.1+1684801906.cuda120.cudnn891.nccl218.ap
[pip3] torch-mlir==1684442443
[pip3] torch-scatter==2.1.0
[pip3] torch-tb-profiler==0.4.1
[pip3] torchdata==0.7.0.dev20230601
[pip3] torchfile==0.1.0
[pip3] torchvision==0.15.1a0+42759b1
[conda] magma-cuda121 2.6.1 1 pytorch
[conda] mkl 2020.4 h726a3e6_304 conda-forge
[conda] mkl-include 2023.1.0 h84fe81f_48680 conda-forge
[conda] numpy 1.23.5 py38h7042d01_0 conda-forge
[conda] pytorch3d 0.6.2 pypi_0 pypi
[conda] torch 2.0.1+1684801906.cuda120.cudnn891.nccl218.ap pypi_0 pypi
[conda] torch-mlir 1684442443 pypi_0 pypi
[conda] torch-scatter 2.1.0 pypi_0 pypi
[conda] torch-tb-profiler 0.4.1 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchvision 0.15.1a0+42759b1 pypi_0 pypi
```
I think you can rely on dlv2.state_dict() to get the state. But note that it is still in prototype mode, so it might have some errors.
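Roughly the flow that state_dict() is meant to support, as a minimal sketch only (the API is prototype-stage, so names and behaviour may still change; variable names mirror the snippet above):

```python
# Sketch of the suggested checkpoint/resume flow with DataLoader2 (prototype API).
from torchdata.dataloader2 import DataLoader2, InProcessReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(20)).shuffle()
dl = DataLoader2(dp, reading_service=InProcessReadingService())

it = iter(dl)
partial = [next(it) for _ in range(4)]   # consume part of the epoch
state = dl.state_dict()                  # checkpoint mid-epoch

# Later (or in another process): rebuild the loader from the saved state.
resumed = DataLoader2.from_state(state, reading_service=InProcessReadingService())
remaining = list(resumed)                # ideally the 16 items not yet yielded
```

If restoration works as intended, `remaining` should hold the 16 not-yet-yielded items; the report above shows it currently yields all 20.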
I did try dlv2.state_dict(), but it didn't work; see example 1 in the snippet above.
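Until that is fixed, the only variant that round-trips correctly for me is example 5. Here is a sketch that consolidates it into a helper; the function name and arguments are mine, not torchdata API, and it reuses the `get_all_pipes` helper shown earlier:

```python
# Hypothetical helper wrapping "example 5" above: fast-forward a deserialized graph,
# then put the checkpointed shuffle buffer and RNG state back by hand.
from torch.utils.data.datapipes._snapshot import _simple_graph_snapshot_restoration
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

def restore_shuffled_datapipe(restored_dp, n_yielded):
    shufflers = [p for p in get_all_pipes(restored_dp) if isinstance(p, ShufflerIterDataPipe)]
    # The deserialized Shufflers still carry the checkpointed buffer/RNG; save them
    # before _simple_graph_snapshot_restoration re-iterates and overwrites them.
    saved = [(p._buffer[:], p._rng.getstate()) for p in shufflers]
    _simple_graph_snapshot_restoration(restored_dp, n_yielded)
    for pipe, (buf, rng_state) in zip(shufflers, saved):
        pipe._buffer = buf[:]
        pipe._rng.setstate(rng_state)
    return restored_dp

# Usage, mirroring the snippet above:
# dp_resumed = restore_shuffled_datapipe(
#     deserialize_datapipe(serialize_datapipe(dl.datapipe)),
#     dl.datapipe._number_of_samples_yielded,
# )
```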