pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License
1.12k stars 149 forks source link

Crash when resetting datapipe with buffer of filehandles #1161

Open falckt opened 1 year ago

falckt commented 1 year ago

🐛 Describe the bug

When the datapipe iterator is reset, the multiprocessing reading service tries to pickle the datapipe (why?). If the datapipe contains a buffer with file handles, this fails. An MWE:

# curl -O http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar

import torchdata.datapipes as dp
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

# strip filehandles
def decode(data):
    """Map every key of ``data`` to itself, discarding the original values.

    Used to strip the (unpicklable) file handles out of a sample so that
    only the keys are passed downstream.
    """
    stripped = {}
    for key, _value in data.items():
        stripped[key] = key
    return stripped

# Build the pipeline. The shuffle buffer sits *before* `decode`, so the
# buffered samples still carry the file handles read from the tar archives.
webds_dp = (
    dp.iter.FileLister(".", masks="openimages*.tar")
    .open_files("rb")
    .load_from_tar()
    .webdataset()
    .shuffle(buffer_size=50)
    .header(5)  # to trigger crash faster
    .map(decode)
)

# Run the pipeline in 4 worker processes.
dl = DataLoader2(
    webds_dp,
    reading_service=MultiProcessingReadingService(4),
)

# First pass: completes without error.
for _ in dl:
    pass

print("first pass completed")

# Second pass resets the dataloader; the workers then crash with
# "TypeError: cannot pickle 'ExFileObject' object" (see traceback below).
for _ in dl:
    pass

I can avoid the crash by decoding the data before feeding it into the buffer, but I would like to delay the expensive decode until after the buffer, so that the datapipe loads faster after initialization.

Traceback

first pass completed
Process ForkProcess-3:
Process ForkProcess-1:
Process ForkProcess-2:
Process ForkProcess-4:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/eventloop.py", line 133, in DataPipeToQueuesLoop
    for _ in loop:
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/iter.py", line 153, in DataPipeBehindQueues
    source_datapipe = request.reset_fn(source_datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/utils/worker.py", line 164, in process_reset_fn
    graph = traverse_dps(datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 98, in traverse_dps
    return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 140, in _traverse_helper
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 67, in _list_connected_datapipes
    p.dump(scan_obj)
TypeError: cannot pickle 'ExFileObject' object
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/eventloop.py", line 133, in DataPipeToQueuesLoop
    for _ in loop:
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/iter.py", line 153, in DataPipeBehindQueues
    source_datapipe = request.reset_fn(source_datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/utils/worker.py", line 164, in process_reset_fn
    graph = traverse_dps(datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 98, in traverse_dps
    return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/eventloop.py", line 133, in DataPipeToQueuesLoop
    for _ in loop:
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/eventloop.py", line 133, in DataPipeToQueuesLoop
    for _ in loop:
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/iter.py", line 153, in DataPipeBehindQueues
    source_datapipe = request.reset_fn(source_datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 140, in _traverse_helper
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/iter.py", line 153, in DataPipeBehindQueues
    source_datapipe = request.reset_fn(source_datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/utils/worker.py", line 164, in process_reset_fn
    graph = traverse_dps(datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 67, in _list_connected_datapipes
    p.dump(scan_obj)
  File "/opt/conda/lib/python3.10/site-packages/torchdata/dataloader2/utils/worker.py", line 164, in process_reset_fn
    graph = traverse_dps(datapipe)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 98, in traverse_dps
    return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 98, in traverse_dps
    return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
TypeError: cannot pickle 'ExFileObject' object
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 140, in _traverse_helper
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 140, in _traverse_helper
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 67, in _list_connected_datapipes
    p.dump(scan_obj)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 67, in _list_connected_datapipes
    p.dump(scan_obj)
TypeError: cannot pickle 'ExFileObject' object
TypeError: cannot pickle 'ExFileObject' object

Versions

Test run in pytorch-nightly docker image. (Also fails on pytorch 2.0 with torchdata 0.6.)

PyTorch version: 2.1.0.dev20230514 Is debug build: False CUDA used to build PyTorch: 11.7 ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64) GCC version: Could not collect Clang version: Could not collect CMake version: version 3.22.1 Libc version: glibc-2.31

Python version: 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] (64-bit runtime) Python platform: Linux-5.19.0-1024-aws-x86_64-with-glibc2.31 Is CUDA available: False CUDA runtime version: No CUDA CUDA_MODULE_LOADING set to: N/A GPU models and configuration: No CUDA Nvidia driver version: No CUDA cuDNN version: No CUDA HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True

CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Byte Order: Little Endian Address sizes: 48 bits physical, 48 bits virtual CPU(s): 64 On-line CPU(s) list: 0-63 Thread(s) per core: 2 Core(s) per socket: 32 Socket(s): 1 NUMA node(s): 1 Vendor ID: AuthenticAMD CPU family: 23 Model: 49 Model name: AMD EPYC 7R32 Stepping: 0 CPU MHz: 2799.972 BogoMIPS: 5599.94 Hypervisor vendor: KVM Virtualization type: full L1d cache: 1 MiB L1i cache: 1 MiB L2 cache: 16 MiB L3 cache: 128 MiB NUMA node0 CPU(s): 0-63 Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Not affected Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Not affected Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid

Versions of relevant libraries: [pip3] numpy==1.24.3 [pip3] torch==2.1.0.dev20230514 [pip3] torchaudio==2.1.0.dev20230514 [pip3] torchdata==0.7.0.dev20230514 [pip3] torchelastic==0.2.2 [pip3] torchtext==0.16.0.dev20230514 [pip3] torchvision==0.16.0.dev20230514 [pip3] triton==2.1.0 [conda] blas 1.0 mkl [conda] mkl 2023.1.0 h6d00ec8_46342 [conda] mkl-service 2.4.0 py310h5eee18b_1 [conda] mkl_fft 1.3.6 py310h1128e8f_1 [conda] mkl_random 1.2.2 py310h1128e8f_1 [conda] numpy 1.24.3 py310h5f9d8c6_1 [conda] numpy-base 1.24.3 py310hb5e798b_1 [conda] pytorch 2.1.0.dev20230514 py3.10_cuda11.7_cudnn8.5.0_0 pytorch-nightly [conda] pytorch-cuda 11.7 h778d358_5 pytorch-nightly [conda] pytorch-mutex 1.0 cuda pytorch-nightly [conda] torchaudio 2.1.0.dev20230514 py310_cu117 pytorch-nightly [conda] torchdata 0.7.0.dev20230514 py310 pytorch-nightly [conda] torchelastic 0.2.2 pypi_0 pypi [conda] torchtext 0.16.0.dev20230514 py310 pytorch-nightly [conda] torchtriton 2.1.0+7d1a95b046 py310 pytorch-nightly [conda] torchvision 0.16.0.dev20230514 py310_cu117 pytorch-nightly

ejguan commented 1 year ago

Thanks for reporting it. Similar to the solution for https://github.com/pytorch/data/issues/1150, we need a special wrapper to wrap the object to indicate we don't want it to be traversed during graph retrieval.

josiahls commented 1 year ago

I have a similar issue for reinforcement learning using the in_memory_cache. We basically have a pipe that has a buffer with un-picklable objects. Below is a minimal example:

import gymnasium as gym
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps

def make_env(env:str):
    """Print the environment id ``env`` and return the environment built for it.

    The environment is created via ``gym.make`` with ``render_mode='rgb_array'``.
    """
    print(env)
    created = gym.make(env, render_mode='rgb_array')
    return created
def reset_env(env):
    """Call ``reset()`` and then ``render()`` on ``env`` and return the same object."""
    env.reset()
    # NOTE(review): rendering is presumably what attaches the unpicklable
    # 'pygame.Surface' reported in the error output — confirm.
    env.render()
    return env

# Three copies of the same environment id, each turned into an env by make_env.
pipe = dp.iter.IterableWrapper(['CartPole-v1']*3)
pipe = pipe.map(make_env)
# Once we have done a full iteration over the envs, we don't want to create
# new ones, so we cache them in memory and cycle through them.
pipe = dp.iter.InMemoryCacheHolder(pipe) 
pipe = pipe.cycle()
pipe = pipe.map(reset_env)
pipe = pipe.header(10)
traverse_dps(pipe) # <- works fine: the pipe has not been iterated, so no env has been created or rendered yet
for o in pipe:pass    # Load everything into the memory cache holder
traverse_dps(pipe)  # Traversing again fails with "TypeError: cannot pickle 'pygame.Surface' object"

Outputs:

File [/usr/local/lib/python3.8/dist-packages/torch/utils/data/graph.py:67](https://vscode-remote+dev-002dcontainer-002b2f686f6d652f6a6f736961682f5079636861726d50726f6a656374732f66617374726c.vscode-resource.vscode-cdn.net/usr/local/lib/python3.8/dist-packages/torch/utils/data/graph.py:67), in _list_connected_datapipes(scan_obj, only_datapipe, cache)
     65         cls.set_getstate_hook(getstate_hook)
     66 try:
---> 67     p.dump(scan_obj)
     68 except (pickle.PickleError, AttributeError, TypeError):
     69     if DILL_AVAILABLE:

TypeError: cannot pickle 'pygame.Surface' object

My solution is:

class PickleableInMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
...
    def __getstate__(self):
        """Return the picklable state: only ``source_dp`` and ``size``.

        The cache contents are deliberately left out of the state, so
        unpicklable cached objects are never handed to pickle.
        """
        state = (
            self.source_dp,
            self.size
        )
        # Give a registered getstate hook the chance to transform the state.
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        """Restore ``source_dp`` and ``size`` from ``state`` and start with an empty cache.

        ``state`` is the 2-tuple ``(source_dp, size)`` produced by ``__getstate__``.
        """
        (
            self.source_dp,
            self.size
        ) = state
        # The cache is not part of the pickled state, so it is reset here.
        self.cache: Optional[Deque] = None
        self.idx: int = 0

There are issues with this, since the cache will be wiped; however, I don't see an alternative, since the cache would contain initialized envs. In fact, this might be the preferred behavior: the envs need to be re-initialized anyway, and having the cache empty will force it to call self.source_dp to re-init them.

pableeto commented 1 year ago

I've run into a similar problem that is caused by the pickle operation during reset. I tried to figure out the logic of resetting dataloaders but failed. May I know why there is such a pickle operation during graph traversal?