apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

DataLoader crashes with threads and slows down with processes #13945

Open mfiore opened 5 years ago

mfiore commented 5 years ago

Description

(Sorry for the long issue. Should it be split in two, one for multiprocessing and one for threads? I posted them together because I thought they might be related, since part of the code is shared.)

Hello, I'm trying to train an SSD network using GluonCV. My dataset is a record file loaded with RecordFileDetection, and I'm using gluon.data.DataLoader with SSDDefaultTrainTransform (I took most of the code from the GluonCV sample script at https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/ssd/train_ssd.py).

There are heavy slowdowns while iterating over the batches. I've tried different batch sizes and numbers of workers. When I measure the time to load a batch in the loop, it is normally around 0.02 s, but there are random spikes of 4, 5, or even 7 seconds.
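To separate the steady-state load time from the occasional spike, it can help to record every batch's wait time and flag the outliers. A minimal sketch, assuming Python 3 (the `statistics` module is not available in Python 2.7); `time_batches` and `find_spikes` are hypothetical helpers, and the 10x-median threshold is an arbitrary choice:

```python
import statistics
import time


def time_batches(loader, max_batches=None):
    """Return the seconds spent waiting for each batch from `loader`.

    `loader` can be any iterable (for example a gluon DataLoader). The timer
    restarts right after each batch is received, so only the wait is measured.
    """
    times = []
    start = time.perf_counter()
    for i, _batch in enumerate(loader):
        times.append(time.perf_counter() - start)
        if max_batches is not None and i + 1 >= max_batches:
            break
        start = time.perf_counter()
    return times


def find_spikes(times, factor=10.0):
    """Indices of batches that took more than `factor` times the median wait."""
    if not times:
        return []
    median = statistics.median(times)
    return [i for i, t in enumerate(times) if t > factor * median]
```

With the DataLoader from the reproducible example further down, `find_spikes(time_batches(train_loader, max_batches=200))` would list which batches were slow, which makes it easier to check whether the spikes recur at a regular interval.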

I then tried using thread_pool=True in my DataLoader. In this case, reading from the RecordIO file makes the program crash.

Environment info (Required)

----------Python Info----------
('Version      :', '2.7.12')
('Compiler     :', 'GCC 5.4.0 20160609')
('Build        :', ('default', 'Nov 12 2018 14:36:49'))
('Arch         :', ('64bit', 'ELF'))
------------Pip Info-----------
('Version      :', '18.1')
('Directory    :', '/home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/pip')
----------MXNet Info-----------
('Version      :', '1.5.0')
('Directory    :', '/home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet')
('Commit Hash   :', 'e8a2b8b9fdafaccbf65397cec142fffcae2289b7')
----------System Info----------
('Platform     :', 'Linux-4.15.0-43-generic-x86_64-with-Ubuntu-16.04-xenial')
('system       :', 'Linux')
('node         :', 'SLVIDPUBN001')
('release      :', '4.15.0-43-generic')
('version      :', '#46~16.04.1-Ubuntu SMP Fri Dec 7 13:31:08 UTC 2018')
----------Hardware Info----------
('machine      :', 'x86_64')
('processor    :', 'x86_64')
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                8
On-line CPU(s) list:   0-7
Thread(s) per core:    2
Core(s) per socket:    4
Socket(s):             1
NUMA node(s):          1
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 158
Model name:            Intel(R) Core(TM) i7-7700K CPU @ 4.20GHz
Stepping:              9
CPU MHz:               4500.178
CPU max MHz:           4500.0000
CPU min MHz:           800.0000
BogoMIPS:              8400.00
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              8192K
NUMA node0 CPU(s):     0-7
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp flush_l1d
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0013 sec, LOAD: 0.5820 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0286 sec, LOAD: 0.7725 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0395 sec, LOAD: 0.3612 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0425 sec, LOAD: 0.1815 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0388 sec, LOAD: 0.8965 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.2893 sec, LOAD: 0.8236 sec.

Error Message:

Error message when training with thread_pool=True

INFO:root:Started training from [Epoch 0]
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 410 extraneous bytes before marker 0xd9
Corrupt JPEG data: 288 extraneous bytes before marker 0xd9
Corrupt JPEG data: 158 extraneous bytes before marker 0xd9
Corrupt JPEG data: 422 extraneous bytes before marker 0xd9
Corrupt JPEG data: 11 extraneous bytes before marker 0xd9
Corrupt JPEG data: 498 extraneous bytes before marker 0xd9
Corrupt JPEG data: 19227 extraneous bytes before marker 0xd9
Corrupt JPEG data: 3195 extraneous bytes before marker 0xd9
Traceback (most recent call last):
  File "train_ssd.py", line 574, in <module>
    cls_list, summary_writer)
  File "train_ssd.py", line 405, in train
    for i, batch in enumerate(train_data):
  File "/home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/gluon/data/dataloader.py", line 452, in next
    return self.__next__()
  File "/home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/gluon/data/dataloader.py", line 444, in __next__
    batch = ret.get()
  File "/usr/lib/python2.7/multiprocessing/pool.py", line 567, in get
    raise self._value
mxnet.base.MXNetError: [10:55:43] src/io/image_io.cc:146: Check failed: !res.empty() Decoding failed. Invalid image file.

Stack trace returned 10 entries:
[bt] (0) /home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x405bda) [0x7fc1863fcbda]
[bt] (1) /home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x4061f1) [0x7fc1863fd1f1]
[bt] (2) /home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/libmxnet.so(mxnet::io::ImdecodeImpl(int, bool, void*, unsigned long, mxnet::NDArray*)+0x9a6) [0x7fc188d3a746]
[bt] (3) /home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/libmxnet.so(mxnet::io::Imdecode(nnvm::NodeAttrs const&, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> > const&, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> >*)+0xe73) [0x7fc188d3d023]
[bt] (4) /home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x228) [0x7fc188d1a038]
[bt] (5) /home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x2c264d9) [0x7fc188c1d4d9]
[bt] (6) /home/mfiore/.virtualenvs/gluoncv/local/lib/python2.7/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7fc188c1dacf]
[bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7fc1b2ec4e40]
[bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7fc1b2ec48ab]
[bt] (9) /home/mfiore/.virtualenvs/gluoncv/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7fc1b30d43df]

Minimum reproducible example

(Set use_threads=True to reproduce the threading issue.)

from __future__ import print_function  # the environment above is Python 2.7

import time

import mxnet as mx
from mxnet import autograd, gluon
from gluoncv import model_zoo
from gluoncv.data import RecordFileDetection
from gluoncv.data.batchify import Tuple, Stack
from gluoncv.data.transforms.presets.ssd import SSDDefaultTrainTransform

ctx = [mx.gpu(0)]
num_workers = 7
batch_size = 32
height = 512
width = 512
rec_path = 'train.rec'
use_threads = False  # set to True to reproduce the thread_pool crash

train_dataset = RecordFileDetection(rec_path, coord_normalized=True)

net = model_zoo.get_model('ssd_512_mobilenet1.0_voc', pretrained_base=True)

# Initialize only the parameters that were not loaded with the pretrained base
for param in net.collect_params().values():
    if param._data is not None:
        continue
    param.initialize()

# In train mode the network also returns the anchors needed by the target generator
with autograd.train_mode():
    _, _, anchors = net(mx.nd.zeros((1, 3, height, width)))
batchify_fn = Tuple(Stack(), Stack(), Stack())  # stack image, cls_targets, box_targets

train_loader = gluon.data.DataLoader(
    train_dataset.transform(SSDDefaultTrainTransform(width, height, anchors)),
    batch_size,
    shuffle=True,
    batchify_fn=batchify_fn,
    last_batch='rollover',
    num_workers=num_workers,
    thread_pool=use_threads)

net.hybridize(static_alloc=True)

# Print how long each batch takes to arrive from the loader
start_batch_time = time.time()
for i, batch in enumerate(train_loader):
    print("Load batch time is %.3f s" % (time.time() - start_batch_time))
    start_batch_time = time.time()

What have you tried to solve it?

With thread_pool=False: I've tried changing num_workers. With num_workers=0 the slowdowns don't seem to happen (it's uniformly slow, of course =) ). With as few as two workers I start seeing the issue, and it doesn't change as I keep increasing the number of workers. I compared the time to iterate over 20 batches with batch_size=32 against the training script from the deprecated MXNet SSD repository, and this loader is less than half as fast. After timing the code, the issue seems to be related to the time spent in pickle.loads on line 443 of dataloader.py.
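Since both `pickle.loads` and `ret.get()` (see zhreshold's reply) sit on line 443, a timer around that line measures the time spent blocking on the worker as well as the deserialization itself. A minimal sketch of splitting the two, assuming Python 3 and using a stdlib `ThreadPool` with a hypothetical worker `make_pickled_batch` in place of the real DataLoader workers:

```python
import pickle
import time
from multiprocessing.pool import ThreadPool


def make_pickled_batch(n, delay=0.05):
    """Hypothetical worker: simulates `delay` seconds of decoding/augmentation,
    then returns its result already pickled, as a worker process would."""
    time.sleep(delay)
    return pickle.dumps(list(range(n)), protocol=pickle.HIGHEST_PROTOCOL)


def timed_fetch(async_result):
    """Split `pickle.loads(async_result.get())` into its two costs."""
    t0 = time.perf_counter()
    raw = async_result.get()   # blocks until the worker has produced the batch
    t1 = time.perf_counter()
    batch = pickle.loads(raw)  # deserialization only
    t2 = time.perf_counter()
    return batch, t1 - t0, t2 - t1


if __name__ == "__main__":
    with ThreadPool(1) as pool:
        ret = pool.apply_async(make_pickled_batch, (100000,))
        batch, wait_s, decode_s = timed_fetch(ret)
    print("waited %.3f s, unpickled in %.3f s" % (wait_s, decode_s))
```

If most of the time shows up as waiting rather than unpickling, the workers are simply not keeping up, which would point at the dataset/transform side rather than at pickle.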

With thread_pool=True, I tried removing prefetching and the error disappears, but loading is quite slow (3 s per batch).

Looking a bit at the DataLoader code, I found the following: 1) there is no result check on the async call; is this correct? 2) If I stop before line 442 of dataloader.py, I see that self._data_buffer is filled, but many results are not successful. Trying to get them gives me one of the following errors:

piyushghai commented 5 years ago

@mfiore Thanks for raising this. I'm labelling it so that other community members can help resolve this.

@mxnet-label-bot Add [Bug, Data-loading, Python]

anirudhacharya commented 5 years ago

ThreadPool support was introduced in this PR https://github.com/apache/incubator-mxnet/pull/13606

cc @zhreshold

zhreshold commented 5 years ago

@mfiore

Is your dataset small? Do you know how many batches there are in each epoch? If it's small, the prefetching step will push all the workloads at once, and you will have to wait until the first worker finishes its job.
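To make the prefetching point concrete, here is a simplified model (not MXNet's actual DataLoader code): up to `prefetch` requests are submitted before the first batch is consumed, and one more is submitted each time a batch is taken. When there are fewer batches than `prefetch`, every request is queued up front and the consumer blocks on the oldest one. `prefetching_iter` and `load_batch` are hypothetical names; Python 3 is assumed:

```python
from collections import deque
from multiprocessing.pool import ThreadPool


def prefetching_iter(pool, load_batch, num_batches, prefetch):
    """Yield load_batch(0), ..., load_batch(num_batches - 1) in order, with at
    most `prefetch` submitted-but-unconsumed requests on `pool`.

    Simplified model for illustration only, not MXNet's DataLoader code.
    """
    pending = deque()
    next_idx = 0
    # Initial fill: submit up to `prefetch` requests before consuming anything.
    while next_idx < min(prefetch, num_batches):
        pending.append(pool.apply_async(load_batch, (next_idx,)))
        next_idx += 1
    while pending:
        batch = pending.popleft().get()  # block on the oldest request
        if next_idx < num_batches:       # refill the slot that was just freed
            pending.append(pool.apply_async(load_batch, (next_idx,)))
            next_idx += 1
        yield batch


if __name__ == "__main__":
    with ThreadPool(4) as pool:
        print(list(prefetching_iter(pool, lambda i: i * i, num_batches=10, prefetch=8)))
```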

Regarding your questions:

  1. The async ret value is a multiprocessing.pool.AsyncResult object; it is only synchronized when ret.get() is called (line 443).
  2. As mentioned in 1, you cannot check the results in self._data_buffer, because it holds AsyncResult objects whose tasks may not have finished yet.
  3. The problem seems to be related to the "next" statement at line 426; the crash appears to come from the dataset itself, not the DataLoader. Since you are using RecordIO, one possible reason is that the record file seek function is not thread-safe, as in this line: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/recordio.py#L268. A temporary solution is to add a mutex to guard this line: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/recordio.py#L313

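Points 1 and 2 can be reproduced with the standard library alone: `apply_async` returns an `AsyncResult` immediately, `ready()` is `False` until the task finishes, and `get()` blocks and then either returns the value or re-raises the worker's exception (the `raise self._value` frame in the traceback above). A minimal sketch, assuming Python 3, with a hypothetical `decode` function standing in for the dataset:

```python
import time
from multiprocessing.pool import ThreadPool


def decode(i):
    """Hypothetical stand-in for dataset[i]; record 3 is treated as corrupt."""
    time.sleep(0.05)
    if i == 3:
        raise ValueError("Decoding failed. Invalid image file.")
    return i


def collect(n=5):
    """Submit n tasks, then return (ready flags right after submission, outcomes)."""
    with ThreadPool(2) as pool:
        results = [pool.apply_async(decode, (i,)) for i in range(n)]
        # Most likely nothing has finished yet. Calling successful() on a result
        # that is not ready raises (ValueError on Python 3.7+).
        ready_at_submit = [r.ready() for r in results]
        outcomes = []
        for r in results:
            try:
                outcomes.append(r.get())  # blocks until this task is done
            except ValueError as err:     # the worker's exception, re-raised here
                outcomes.append("failed: %s" % err)
        return ready_at_submit, outcomes


if __name__ == "__main__":
    print(collect())
```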
mfiore commented 5 years ago

@zhreshold thanks for your answer! The dataset has about 178k examples, so with a batch size of 32 it's more than 5k batches per epoch. Putting the mutex on the line you pointed to solves the crash. Now the behavior is the same with threads and with multiprocessing: slowdowns every few batches.
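Why the mutex helps: with a single shared file handle, a seek followed by a read is two separate steps, and another thread can move the file position in between, so the bytes handed to the JPEG decoder may belong to a different record (consistent with the "Corrupt JPEG data" lines above). A minimal stdlib sketch of the fix, assuming Python 3 and using a hypothetical `IndexedReader` over a made-up length-prefixed format rather than MXNet's `MXIndexedRecordIO`; the lock makes seek + read atomic:

```python
import struct
import tempfile
import threading
from multiprocessing.pool import ThreadPool


def write_records(path, payloads):
    """Write each payload as a 4-byte little-endian length plus the bytes and
    return the byte offset of every record (a made-up format, not RecordIO)."""
    offsets = []
    with open(path, "wb") as f:
        for payload in payloads:
            offsets.append(f.tell())
            f.write(struct.pack("<I", len(payload)))
            f.write(payload)
    return offsets


class IndexedReader:
    """Hypothetical indexed reader sharing one file handle between threads."""

    def __init__(self, path, offsets):
        self._f = open(path, "rb")
        self._offsets = offsets
        self._lock = threading.Lock()

    def read_idx(self, idx):
        # One handle means one file position: without the lock, another thread
        # could seek between our seek() and read() and we would read its record.
        with self._lock:
            self._f.seek(self._offsets[idx])
            (length,) = struct.unpack("<I", self._f.read(4))
            return self._f.read(length)

    def close(self):
        self._f.close()


if __name__ == "__main__":
    payloads = [bytes([i % 256]) * (100 + i) for i in range(200)]
    with tempfile.TemporaryDirectory() as tmp:
        path = tmp + "/data.bin"
        reader = IndexedReader(path, write_records(path, payloads))
        with ThreadPool(8) as pool:
            ok = pool.map(reader.read_idx, range(len(payloads))) == payloads
        reader.close()
    print("all records intact:", ok)
```

In MXNet terms, the equivalent would presumably be holding one lock around the seek and the subsequent read of a record (the two recordio.py lines linked above); that mapping is an interpretation of the linked code, not a tested patch.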

zhreshold commented 4 years ago

It's fixed by overriding the pickling behavior of RecordIO files, without needing the mutex. I am closing this now. Please ping me if the problem persists.
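What "overriding the pickling behavior" means in general: an object that holds an open file (or, like MXNet's RecordIO reader, a native handle) can define `__getstate__`/`__setstate__` so that only the path is pickled and the file is reopened on the receiving side, giving each worker its own handle and file position. A minimal sketch of that pattern, assuming Python 3, with a hypothetical `ReopeningReader` over a plain file; it is not MXNet's actual patch:

```python
import pickle
import tempfile


class ReopeningReader:
    """Hypothetical reader that keeps a file open yet can still be pickled.

    Open file objects cannot be pickled, so only the path is sent and the
    file is reopened in __setstate__, giving every copy its own handle and
    its own file position.
    """

    def __init__(self, path):
        self.path = path
        self._f = open(path, "rb")

    def read_at(self, offset, size):
        self._f.seek(offset)
        return self._f.read(size)

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_f"]  # never pickle the live handle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._f = open(self.path, "rb")  # reopen wherever we are unpickled

    def close(self):
        self._f.close()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = tmp + "/data.bin"
        with open(path, "wb") as f:
            f.write(b"0123456789")
        original = ReopeningReader(path)
        clone = pickle.loads(pickle.dumps(original))  # what a worker would receive
        print(clone.read_at(2, 3), original.read_at(7, 3))
        original.close()
        clone.close()
```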

mathephysicist commented 4 years ago

> It's fixed by overriding the pickling behavior of RecordIO files, without needing the mutex. I am closing this now. Please ping me if the problem persists.

Can you please elaborate? I am running into a similar issue. How should I override the pickling behavior?

zhreshold commented 4 years ago

@mathephysicist Multiprocess access to the same RecordIO file is already fixed in master. I created another PR to fix the multithreading case: https://github.com/apache/incubator-mxnet/pull/18366
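For the multithreaded case, one alternative to a single lock is to give each thread its own file handle, so no two threads ever share a file position. A minimal sketch, assuming Python 3, with a hypothetical `ThreadLocalReader` built on `threading.local`; this is one possible approach and not necessarily what the linked PR does:

```python
import tempfile
import threading
from multiprocessing.pool import ThreadPool


class ThreadLocalReader:
    """Hypothetical reader: each thread lazily opens its own file handle.
    After close() the reader should not be used again."""

    def __init__(self, path):
        self.path = path
        self._local = threading.local()
        self._handles = []                 # every handle opened, for close()
        self._handles_lock = threading.Lock()

    def _handle(self):
        f = getattr(self._local, "f", None)
        if f is None:                      # first read from this thread
            f = open(self.path, "rb")
            self._local.f = f
            with self._handles_lock:
                self._handles.append(f)
        return f

    def read_at(self, offset, size):
        f = self._handle()                 # private to the calling thread
        f.seek(offset)
        return f.read(size)

    def close(self):
        with self._handles_lock:
            for f in self._handles:
                f.close()
            self._handles.clear()


if __name__ == "__main__":
    data = bytes(range(256)) * 40
    with tempfile.TemporaryDirectory() as tmp:
        path = tmp + "/data.bin"
        with open(path, "wb") as f:
            f.write(data)
        reader = ThreadLocalReader(path)
        with ThreadPool(8) as pool:
            chunks = pool.map(lambda i: reader.read_at(i * 10, 10), range(1000))
        reader.close()
    print(chunks == [data[i * 10:i * 10 + 10] for i in range(1000)])
```

The trade-off versus a lock is one open file descriptor per thread in exchange for reads that never wait on each other.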