Hi,
In your code, DALI hangs in ExternalInputIterator, waiting indefinitely for the next batch of data. The ExternalInputIterator __next__ method should raise StopIteration when there is no more data, so DALI can process what is left and finish gracefully. Without it, DALI keeps waiting for the next batch of input data. Increasing prefetch_queue_depth just makes the problem more visible, as DALI buffers more iterations ahead and runs out of input data sooner. You can fix your code in the following way to get the desired behavior:
import ctypes
import logging
import numpy as np
import os
import queue
import socket
import threading
import time
import torch
from absl import app, flags
from nvidia.dali import pipeline, ops, types
import nvidia.dali.plugin.pytorch
FLAGS = flags.FLAGS
flags.DEFINE_integer('prefetch_queue_depth', 1, 'Dali prefetch_queue_depth', lower_bound=1)
class ImageBatchGrpc(object):
    def __init__(self):
        self.data = None
        self.infdata = None
        self.infos = []
        self.realsize = 0

    def __len__(self):
        return len(self.infos)

def _fill_queue(q: queue.Queue, batch_size=2, fill_size=-1):
    batchs = int(fill_size / batch_size)
    while batchs > 0:
        fill_queue_with_batch(q, batch_size=batch_size, realsize=batch_size)
        time.sleep(0.5)
        batchs -= 1
    remains = fill_size % batch_size
    if remains > 0:
        fill_queue_with_batch(q, batch_size=batch_size, realsize=remains)

def fill_queue_with_batch(q: queue.Queue, batch_size=2, realsize=2):
"""
:param batch_size: `batch_size` should be even
"""
if batch_size % 2 != 0:
logging.error('invalid batch_size:%d', batch_size)
return
images = []
infos = []
with open('/mnt/storage00/taoxu/1.jpg', 'rb') as f:
images.append(np.frombuffer(f.read(), dtype=np.uint8))
infos.append({'key': 1, 'video_id': 12345678, 'idx': 11, 'ip': '8.8.8.8'})
with open('/mnt/storage00/taoxu/2.jpg', 'rb') as f:
images.append(np.frombuffer(f.read(), dtype=np.uint8))
infos.append({'key': 2, 'video_id': 12345678, 'idx': 22, 'ip': '8.8.8.8'})
image_batch_grpc = ImageBatchGrpc()
images = images * int(batch_size / 2)
image_batch_grpc.data = images
infos = infos * int(batch_size / 2)
image_batch_grpc.infos = infos
image_batch_grpc.realsize = realsize
logging.info('Fill queue batch len:%d realsize:%d', len(image_batch_grpc), image_batch_grpc.realsize)
q.put(image_batch_grpc)
class ExternalInputIterator(object):
    def __init__(self, batch_size, reader: queue.Queue):
        self.batch_size = batch_size
        self.real_reader = reader
        self.last = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.last:
            raise StopIteration
        # logging.info('before calling __next__')
        image_audit_request_key = []
        image_audit_request_video_id = []
        image_audit_request_idx = []
        real_size = []
        ip = []
        try:
            image_batch = self.real_reader.get()
        except:
            logging.exception('Uncaught exception')
            raise StopIteration
        for info_d in image_batch.infos:
            image_audit_request_key.append(info_d['key'])
            image_audit_request_video_id.append(info_d['video_id'])
            image_audit_request_idx.append(info_d['idx'])
            real_size.append(image_batch.realsize)
            ip.append(info_d['ip'])
        self.last = image_batch.realsize < len(image_batch.infos)
        image_jpeg = image_batch.data
        image_batch.data = None
        image_jpeg = [np.frombuffer(b, dtype=np.uint8) for b in image_jpeg]
        real_size = np.array(real_size)
        image_audit_request_key = np.array(image_audit_request_key, dtype=np.int32)
        image_audit_request_video_id = np.array(image_audit_request_video_id, dtype=np.int64)
        image_audit_request_idx = np.array(image_audit_request_idx)
        ip = np.array([np.frombuffer(socket.inet_aton(s), dtype=np.uint8) for s in ip])
        # logging.info('after calling __next__')
        return image_jpeg, real_size, image_audit_request_key, image_audit_request_video_id, image_audit_request_idx, ip

class DataPipeline(pipeline.Pipeline):
    # stride should be loaded from the model
    # TODO: will prefetch_queue_depth affect speed?
    def __init__(self, batch_size, num_threads, device_id, stride, q: queue.Queue):
        """
        :param batch_size:
        :param num_threads:
        :param device_id:
        :param stride:
        """
        super().__init__(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
                         prefetch_queue_depth=FLAGS.prefetch_queue_depth, seed=42)
        self.stride = stride
        self.resize = 320
        self.max_size = 480
        self.mean = [255. * x for x in [0.485, 0.456, 0.406]]
        self.std = [255. * x for x in [0.229, 0.224, 0.225]]
        self.source = ops.ExternalSource(
            source=ExternalInputIterator(batch_size=batch_size, reader=q),
            num_outputs=6)
        self.decode_infer = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.resize_infer = ops.Resize(device='gpu', interp_type=types.DALIInterpType.INTERP_CUBIC,
                                       resize_shorter=self.resize, max_size=self.max_size, save_attrs=True)
        self.normalize = ops.CropMirrorNormalize(device='gpu', mean=self.mean, std=self.std)
        self.pad = ops.Pad(device='gpu', axis_names='HW', align=self.stride, shape=[-1, -1])

    def define_graph(self):
        image_jpeg, real_size, image_audit_request_key, image_audit_request_video_id, image_audit_request_idx, ip = \
            self.source()
        # Layout after the mixed decoder is HWC
        images = self.decode_infer(image_jpeg)
        images, attrs = self.resize_infer(images)
        resized_images = images
        images = self.pad(self.normalize(images))
        return images, attrs, resized_images, real_size, image_audit_request_key, image_audit_request_video_id, \
            image_audit_request_idx, ip

class DataProvider():
"""Data loader for data parallel using Dali"""
def __init__(self, stride, q: queue.Queue):
logging.info('Using DALI provider')
self.stride = stride
self.batch_size = 32
self.pipe = DataPipeline(batch_size=self.batch_size, num_threads=6,
device_id=torch.cuda.current_device(),
stride=self.stride, q=q)
self.pipe.build()
def get_data_batch(self):
try:
# dali_data, dali_resize_img is TensorListGPU
# Others are TensorListCPU
dali_data, dali_attrs, dali_resize_img, real_size, request_key, request_video_id, request_idx, ip = \
self.pipe.run()
except StopIteration:
logging.warning('Caught StopIteration')
return None
dali_data = dali_data.as_tensor()
torch_tensor = torch.zeros(dali_data.shape(), dtype=torch.float, device=torch.device('cuda'))
nvidia.dali.plugin.pytorch.feed_ndarray(dali_data, torch_tensor)
dali_attrs = dali_attrs.as_cpu().as_array()
real_size = real_size.as_array()
request_key = request_key.as_array()
request_video_id = request_video_id.as_array()
request_idx = request_idx.as_array()
ip = ip.as_array()
resized_size = np.empty((self.batch_size, 2))
for i in range(self.batch_size):
resized_size[i] = dali_resize_img[i].shape()[:2]
ratios = np.max(resized_size, axis=1) / np.max(dali_attrs, axis=1)
batch = ImageBatchGrpc()
batch.infos = []
for _ in range(self.batch_size):
item = {}
item['req'] = {'key': request_key[_], 'video_id': request_video_id[_], 'idx': request_idx[_]}
item['ip'] = socket.inet_ntoa(ip[_].tobytes())
batch.infos.append(item)
batch.realsize = real_size[0]
batch.infdata = torch_tensor
batch.otherdata = ratios
logging.info('dali get_data_batch len:%d, realsize:%d', len(batch), batch.realsize)
return batch
def main(argv):
    q = queue.Queue()
    fill_queue_thread = threading.Thread(target=_fill_queue, args=(q,), kwargs={'batch_size': 32, 'fill_size': 169})
    fill_queue_thread.start()
    eii = ExternalInputIterator(batch_size=32, reader=q)
    data_provider = DataProvider(stride=1, q=q)
    while True:
        batch = data_provider.get_data_batch()
        if batch is None:
            break
        logging.info('Batch len:%d, realsize:%d', len(batch), batch.realsize)


if __name__ == '__main__':
    app.run(main)
Thanks for your reply. In fact, the program I pasted is extracted from a server program. It should keep running in the background: as long as data is available, it returns the batch to an outer inference service, and the incoming data never stops. So I changed prefetch_queue_depth to 1; is this the correct or appropriate solution?
As your pipeline works on data provided online, setting prefetch_queue_depth to 1 should do, as there is no point in processing more data ahead of time.
If data comes in faster than the backend can process it, is it meaningful to process data ahead of time?
In some cases it makes sense, as you can overlap data processing with backend processing; the downside is that you cannot consume the results until more data comes.
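For illustration, a minimal sketch of that trade-off, reusing DataPipeline and the queue q from the listing above; the chosen depth and the comments about blocking are assumptions based on this thread, not measurements:

# Sketch only: assumes the program was started with --prefetch_queue_depth=2.
pipe = DataPipeline(batch_size=32, num_threads=6,
                    device_id=torch.cuda.current_device(), stride=1, q=q)
pipe.build()
while True:
    # With a depth > 1, DALI keeps preprocessing upcoming batches while the
    # backend consumes the current one; the downside (as noted above) is that
    # run() stays that many input batches ahead on q and blocks when the
    # producer stops delivering data.
    outputs = pipe.run()
    # ... hand `outputs` to the inference backend here ...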
Thanks. Will a future version support Pipeline.run when the prefetch queue is not fully filled?
I would say that it is somewhat possible now. Instead of using the run method you can use schedule_run, share_outputs, and release_outputs. You can call schedule_run as many times as you want; it just queues more work in DALI for processing. Call share_outputs and release_outputs when you want to get the data from DALI. Keep in mind that if there is no data in DALI, share_outputs will raise StopIteration.
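A minimal sketch of that flow, using the pipeline built in DataProvider above; only schedule_run, share_outputs and release_outputs come from the reply, the surrounding error handling is an assumption:

pipe = data_provider.pipe            # the pipeline built in DataProvider above
pipe.schedule_run()                  # queue one iteration of work in DALI
pipe.schedule_run()                  # further calls just queue more work ahead
try:
    outs = pipe.share_outputs()      # wait for the oldest scheduled iteration
except StopIteration:
    outs = None                      # no data left in DALI, as described above
else:
    # ... consume `outs` here; the buffers stay valid until release_outputs()
    pipe.release_outputs()           # hand the output buffers back to DALI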
When --prefetch_queue_depth=6 is passed to the program above, it gets stuck at Pipeline.run(). However, when --prefetch_queue_depth=1 is passed, the pipeline just works. IMHO, the pipeline should behave the same regardless of prefetch_queue_depth. If the input pipeline is slower than Pipeline.run, why not run eagerly?