pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.

Importing `torchdata.stateful_dataloader` hides `torch` RandomSampler and BatchSampler #1280

Closed · byi8220 closed this issue 3 months ago

byi8220 commented 3 months ago

🐛 Describe the bug

Description

In `torchdata/stateful_dataloader/sampler.py`, several `Sampler` classes from `torch.utils.data` are overwritten at import time:

  1. https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L61-L62
  2. https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L134-L135

The implication is that if code imports `StatefulDataLoader` after importing `torch`, there may be inconsistent definitions of `BatchSampler` and `RandomSampler` at runtime: any sampler constructed before the import keeps referring to the original `torch` class, while anything constructed afterwards gets the patched one. See the gist below for a toy example, where a `StatefulDataLoader` ends up holding a `torch.utils.data.sampler.BatchSampler` rather than a `torchdata.stateful_dataloader.sampler.BatchSampler`.

This may be the root cause of https://github.com/huggingface/accelerate/issues/2894.

How to reproduce

See gist: https://gist.github.com/byi8220/3091215e38d8f1caba01bc015aed32aa
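For convenience, here is a condensed sketch of the same check (a hypothetical simplification of the gist, not a verbatim copy; it illustrates the pre-fix behaviour):

```python
import torch.utils.data as tud

# Build a sampler/batch sampler with the classes bound in torch.utils.data
# *before* torchdata.stateful_dataloader has been imported.
data = list(range(10))
sampler = tud.RandomSampler(data)
batch_sampler = tud.BatchSampler(sampler, batch_size=2, drop_last=False)

# Importing the stateful dataloader (pre-fix) rebinds the names inside
# torch.utils.data to the torchdata subclasses.
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import BatchSampler as StatefulBatchSampler

# The objects created above keep their original class, so the StatefulDataLoader
# ends up holding a non-stateful batch sampler.
dl = StatefulDataLoader(data, batch_sampler=batch_sampler)
print(type(dl.batch_sampler))                              # torch.utils.data.sampler.BatchSampler
print(isinstance(dl.batch_sampler, StatefulBatchSampler))  # False
```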

Versions

PyTorch version: 2.5.0.dev20240628
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-36-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: 12.5.40
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Ti
Nvidia driver version: 555.42.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 5 3600 6-Core Processor
CPU family: 23
Model: 113
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 83%
CPU max MHz: 4208.2031
CPU min MHz: 2200.0000
BogoMIPS: 7200.35
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 3 MiB (6 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] torch==2.5.0.dev20240628
[pip3] torchdata==0.7.1.dev20240627
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] pytorch 2.5.0.dev20240628 py3.12_cpu_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cpu pytorch-nightly
[conda] torchdata 0.7.1.dev20240627 py312 pytorch-nightly

andrewkho commented 3 months ago

Thanks for flagging @byi8220! We put this patch in so the default samplers get the faster, stateful behaviour out of the box. I'll look at pulling this out into a non-monkey-patch solution.

andrewkho commented 3 months ago

cc @gokulavasan, see the original issue in HF Accelerate about assumptions on the number of calls to `iter`: https://github.com/huggingface/accelerate/issues/2894

andrewkho commented 3 months ago

I currently see three approaches to fix this:

1. Fork the DataLoader `__init__` code. This is maybe the best way, but it introduces more forked code.
2. Reproduce the sampler/batch_sampler setup in `__init__` before we call `super().__init__(...)`. This seems hacky and not significantly better than 1).
3. Check `isinstance` after `super().__init__()` and replace. But that is heading down old-school Lightning territory and will surely lead to head-scratching and further problems down the road.

I'm going to go with 1)
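For illustration only, here is a rough sketch of what 1) could look like (hypothetical class and argument names, not the actual patch): the stateful defaults are constructed inside the subclass's own `__init__`, so nothing in `torch.utils.data` needs to be monkey-patched.

```python
from torch.utils.data import DataLoader, SequentialSampler
from torchdata.stateful_dataloader.sampler import BatchSampler, RandomSampler


class SketchStatefulDataLoader(DataLoader):
    """Illustrative sketch of approach 1), not the real implementation."""

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, drop_last=False, **kwargs):
        if batch_sampler is None:
            if sampler is None:
                # Defaults use torchdata's stateful RandomSampler instead of
                # patching torch.utils.data.RandomSampler.
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        # Hand the fully constructed batch_sampler to the parent so it never
        # builds its own (non-stateful) defaults.
        super().__init__(dataset, batch_sampler=batch_sampler, **kwargs)
```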

byi8220 commented 3 months ago

Thanks for resolving this! This appears to have fixed the breaking tests in accelerate, and the repro above shows there is no more monkey patching going on.

However, it might be worth mentioning that if one passes an existing non-stateful BatchSampler to the StatefulDataLoader constructor, the constructed StatefulDataLoader will use the provided (non-stateful) sampler as-is. Only pointing this out since it wasn't clear to me whether this is the intended behavior.

Repro output with nightly 0.7.1.dev20240703+cpu (important stuff in green diff):

--------------------------------------------------------------------------------
BatchSampler before importing `stateful_dataloader`: <class 'torch.utils.data.sampler.BatchSampler'>
RandomSampler before importing `stateful_dataloader`: <class 'torch.utils.data.sampler.RandomSampler'>
--------------------------------------------------------------------------------
type(non_stateful_dataloader.batch_sampler): <class 'torch.utils.data.sampler.BatchSampler'>
type(non_stateful_dataloader.batch_sampler.sampler): <class '__main__.MyRandomSamplerWrapper'>
type(non_stateful_dataloader.batch_sampler.sampler.original_sampler): <class 'torch.utils.data.sampler.RandomSampler'>
--------------------------------------------------------------------------------
+ BatchSampler after importing `stateful_dataloader`: <class 'torch.utils.data.sampler.BatchSampler'>
+ RandomSampler after importing `stateful_dataloader`: <class 'torch.utils.data.sampler.RandomSampler'>
--------------------------------------------------------------------------------
+ # Even after the fix, stateful_dataloader has a non-stateful sampler
type(stateful_dataloader.batch_sampler): <class 'torch.utils.data.sampler.BatchSampler'>
type(stateful_dataloader.batch_sampler.sampler): <class '__main__.MyRandomSamplerWrapper'>
type(stateful_dataloader.batch_sampler.sampler.original_sampler): <class 'torch.utils.data.sampler.RandomSampler'>
--------------------------------------------------------------------------------
type(stateful_dataloader_2.batch_sampler): <class 'torchdata.stateful_dataloader.sampler.BatchSampler'>
type(stateful_dataloader_2.batch_sampler.sampler): <class 'torch.utils.data.sampler.SequentialSampler'>
--------------------------------------------------------------------------------
andrewkho commented 3 months ago

@byi8220 thanks for the details! If I understand correctly, the user has explicitly passed a non-stateful BatchSampler to the DataLoader constructor? In that case, I think this behaviour is correct: we should respect what the user passed in and not try to override it under the hood. I've seen code in other libraries that does this kind of override, and it can cause nasty surprises and unintuitive behaviour, and can be very hard to debug.

byi8220 commented 3 months ago

> user has explicitly passed in a non-stateful Batch Sampler to DataLoader constructor?

Yes

> In this case, I think this is correct, we should respect what the user has passed in and not try to override it under the hood.

This makes sense. I was just curious whether this would cause an issue with saving or loading a state dict into this dataloader.

andrewkho commented 3 months ago

That's a great call. Yes, it might cause an issue if a checkpoint was saved before the code change and then loaded with it, but it might also just try to fast-forward; I haven't attempted it. We'll be cutting a release this month, so once that's out it should be easier to manage these types of changes.

For the case where users are explicitly passing in a BatchSampler, they can import it from `torchdata.stateful_dataloader.sampler` instead of from `torch.utils.data`.
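For example (a small sketch assuming a map-style dataset; the import path follows the `sampler.py` module linked above):

```python
from torchdata.stateful_dataloader import StatefulDataLoader
# Import the stateful variants explicitly instead of relying on any patching
# of torch.utils.data.
from torchdata.stateful_dataloader.sampler import BatchSampler, RandomSampler

dataset = list(range(100))  # any map-style dataset works here
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=False)

dl = StatefulDataLoader(dataset, batch_sampler=batch_sampler)
it = iter(dl)
next(it)                 # consume one batch
state = dl.state_dict()  # progress is captured because the batch sampler is stateful

# Later: build a fresh loader the same way and resume from the saved state.
dl2 = StatefulDataLoader(
    dataset,
    batch_sampler=BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=False),
)
dl2.load_state_dict(state)
```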

byi8220 commented 3 months ago

I'm not super familiar with the stateful_dataloader code, but I agree with your hunch that it should fall back to fast forwarding.

Still, it might be worth adding a warning? At worst it might help someone catch a hard-to-find bug; and besides, isn't fast-forwarding an iterator costly?