You don't need to use a decoder operator; VideoReader already returns batches of decoded sequences. Something like this should work in your case:
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import os
vid_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/video/cfr_ntsc_29_97_test.mp4')
batch_size = 1
sequence_length = 8
initial_prefetch_size = 16
num_iterations = 5
class VideoPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data, shuffle):
        super(VideoPipeline, self).__init__(batch_size, num_threads, device_id, seed=0)
        self.input = ops.VideoReader(device='gpu', filenames=data, sequence_length=sequence_length,
                                     shard_id=0, num_shards=1,
                                     random_shuffle=shuffle,
                                     initial_fill=initial_prefetch_size)
        self.resize = ops.Resize(size=(720, 720), device='gpu')
        self.grayscale = ops.CoordTransform(M=[0.35, 0.5, 0.15] * 3, dtype=types.UINT8, device='gpu')
        self.normalize = ops.Normalize(mean=0.5, stddev=0.5, device='gpu')

    def define_graph(self):
        frames = self.input(name='Reader')
        frames = self.resize(frames)
        frames = self.grayscale(frames)
        frames = self.normalize(frames)
        return frames
video_pipe = VideoPipeline(batch_size=batch_size, num_threads=2, device_id=0, data=vid_path, shuffle=False)
pipes = [video_pipe]
pipes[0].build()
dali_iterator = DALIGenericIterator(pipes, ['data'], pipes[0].epoch_size('Reader'))
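For completeness, here is a minimal sketch of consuming that iterator (it uses the same data[0]['data'] access that appears later in this thread; with the settings above each element is an NFHWC tensor of shape (1, 8, 720, 720, 3)):

# Minimal consumption loop (sketch). DALIGenericIterator yields a list with one
# dict per pipeline; the 'data' key holds the decoded sequence as a GPU tensor.
for i, data in enumerate(dali_iterator):
    sequence = data[0]['data']
    print(i, sequence.shape)
    if i == num_iterations - 1:
        break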
@Wilann CoordTransform is a very versatile operator. It basically allows you to apply a linear transformation to the innermost dimension of your data. In this case, it's channels. You can apply a grayscale matrix:
M = [0.35, 0.5, 0.15]
if you want to preserve 3 channels but just make them uniform:
M = [0.35, 0.5, 0.15] * 3
A random saturation is a bit more involved, but quite possible (I use the new functional API here, which I also encourage you to try):
gray = np.float32([[0.35, 0.5, 0.15]] * 3)
id = np.identity(3, dtype=np.float32)
M = dali.fn.uniform(range=(0, 1)) * (gray - id) + id
frames = dali.fn.coord_transform(frames, M=M, dtype=types.UINT8, device='gpu')
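To see what that matrix does numerically, here's a small NumPy-only sanity check (outside DALI) showing that M interpolates between the original colors and their grayscale value as the random factor goes from 0 to 1 (CoordTransform applies M to the innermost, channel dimension, so M @ pixel models it):

import numpy as np

gray = np.float32([[0.35, 0.5, 0.15]] * 3)
ident = np.identity(3, dtype=np.float32)
pixel = np.float32([200, 100, 50])   # an example RGB value

for s in (0.0, 0.5, 1.0):            # s plays the role of dali.fn.uniform(range=(0, 1))
    M = s * (gray - ident) + ident
    print(s, M @ pixel)
# s=0.0 -> [200. 100.  50.]          (colors unchanged)
# s=1.0 -> [127.5 127.5 127.5]       (fully desaturated: all channels equal)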
@JanuszL Thank you for the modified code!
@mzient Thank you for the explanation of CoordTransform and its other use-cases!
Strangely, when using the iterator to load data into my model, it's getting stuck at sequence 103/1329 every time. This is a binary classification problem. I'm trying to pass frames to my model where if it predicts the frame to be of class 0, I save that frame number. Does this have to do with the VideoPipeline?
Here's the modified pipeline:
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import nvidia.dali.ops as ops
import nvidia.dali.types as types
batch_size = 1
sequence_length = 100
initial_prefetch_size = 100
class VideoPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data, shuffle):
        super(VideoPipeline, self).__init__(batch_size, num_threads, device_id, seed=0)
        self.input = ops.VideoReader(device='gpu', filenames=data, sequence_length=sequence_length,
                                     shard_id=0, num_shards=1,
                                     random_shuffle=shuffle, initial_fill=initial_prefetch_size)
        self.resize = ops.Resize(size=(720, 720), device='gpu')
        # Apply a grayscale matrix linear transformation to the innermost dimension of your data - the channels
        self.grayscale = ops.CoordTransform(M=[0.35, 0.5, 0.15], dtype=types.UINT8, device='gpu')
        self.normalize = ops.Normalize(mean=0.5, stddev=0.5, device='gpu')

    def define_graph(self):
        frames = self.input(name='Reader')
        frames = self.resize(frames)
        frames = self.grayscale(frames)
        frames = self.normalize(frames)
        return frames
video_pipe = VideoPipeline(batch_size=batch_size, num_threads=4, device_id=0, data=vid_path, shuffle=False)
pipes = [video_pipe]
pipes[0].build()
dali_iterator = DALIGenericIterator(pipes, ['data'], pipes[0].epoch_size('Reader'))
num_sequences = dali_iterator.size
print('num_sequences:', num_sequences)
Its output:
num_sequences: 1329
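(As a sanity check on that number: the ffmpeg run further below decodes 132950 frames for this video, and 132950 // sequence_length = 132950 // 100 = 1329 complete sequences, which lines up with the reader's epoch size.)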
Here's the inference code:
def get_crop_frames(net, vid_path, device):
    '''
    Get frames to crop using a neural network on GPU
    Note: `start_frame` and `end_frame` are used to reduce computation time
    Parameters:
        net: Neural network
        vid_path: Path to original video
        -
    '''
    start = time.time()
    end = time.time()
    total_process_time = 0
    crop_frames = []
    net.eval()
    zero_tensor = torch.zeros(1, device=device)

    # Iterate over all batches
    for i, data in enumerate(dali_iterator, 1):
        data = torch.squeeze(data[0]['data'], 0)  # (1, 8, 720, 720, 1) --> (8, 720, 720, 1)
        data = data.transpose(1, 3)               # (8, 720, 720, 1) --> (8, 1, 720, 720)

        # Get log-probabilities
        log_probs = net(data)

        # Convert probabilities to predicted class
        _, class_preds = log_probs.data.max(1, keepdims=True)
        class_preds = np.squeeze(class_preds)

        # If the model thinks this frame is at the high-angle, save the timestamp
        for idx in range(0, len(data)):
            if class_preds[idx] == zero_tensor:
                crop_frames.append((i-1)*sequence_length + idx)

        # Print statistics
        seq_process_time = time.time() - end
        total_process_time += seq_process_time
        average_process_time = total_process_time / i
        num_seq_left = num_sequences - i
        estimated_time_left = average_process_time * num_seq_left
        estimated_time_left_h = int(estimated_time_left // 60 // 60)
        estimated_time_left_min = int((estimated_time_left - estimated_time_left_h*60*60) // 60)
        estimated_time_left_sec = int((estimated_time_left - estimated_time_left_min*60) % 60)
        end = time.time()
        time_h = int((end-start) // 60 // 60)
        time_min = int(((end-start) - time_h*60*60) // 60)
        time_sec = round(((end-start) - time_min*60) % 60)
        frame_num = sequence_length * i
        sys.stdout.write('\rProcessing Sequence: {:.2f}% {}/{} Frames Saved: {:.2f}% {}/{} Time Elapsed: {} h {} min {} s ETA: {} h {} m {} s'
                         .format(i*100/num_sequences, i, num_sequences,
                                 len(crop_frames)*100/frame_num, len(crop_frames), frame_num,
                                 time_h, time_min, time_sec,
                                 estimated_time_left_h, estimated_time_left_min, estimated_time_left_sec))
        sys.stdout.flush()

    return crop_frames
device = torch.device('cuda')
net = Net(sequence_length).to(device)
net.load_state_dict(torch.load(model_path))
print('Device:', device)
print('Model Path:', model_path)
print(net)
crop_frames = get_crop_frames(net, vid_path, device)
And the output:
Device: cuda
Model Path: ./models/model_v2_loss=0.0061.pt
Net(
(conv1): Conv2d(1, 6, kernel_size=(10, 10), stride=(2, 2), padding=(10, 10))
(conv2): Conv2d(6, 9, kernel_size=(10, 10), stride=(2, 2), padding=(10, 10))
(conv3): Conv2d(9, 12, kernel_size=(10, 10), stride=(2, 2), padding=(10, 10))
(conv4): Conv2d(12, 15, kernel_size=(10, 10), stride=(2, 2), padding=(10, 10))
(conv5): Conv2d(15, 18, kernel_size=(10, 10), stride=(2, 2), padding=(10, 10))
(conv6): Conv2d(18, 21, kernel_size=(10, 10), stride=(2, 2), padding=(10, 10))
(conv7): Conv2d(21, 24, kernel_size=(10, 10), stride=(2, 2), padding=(10, 10))
(fc1): Linear(in_features=6936, out_features=3468, bias=True)
(fc2): Linear(in_features=3468, out_features=1734, bias=True)
(fc3): Linear(in_features=1734, out_features=867, bias=True)
(fc4): Linear(in_features=867, out_features=433, bias=True)
(fc5): Linear(in_features=433, out_features=216, bias=True)
(fc6): Linear(in_features=216, out_features=108, bias=True)
(fc7): Linear(in_features=108, out_features=54, bias=True)
(fc8): Linear(in_features=54, out_features=2, bias=True)
)
Processing Sequence: 7.75% 103/1329 Frames Saved: 0.70% 72/10300 Time Elapsed: 0 h 0 min 24 s ETA: 0 h 4 m 42 s
Can you check if your video has a variable frame rate? If so, that is the reason, as DALI doesn't currently support such videos. There is a check that should warn you about this, but it is based on a heuristic that may yield the wrong result.
Running from a notebook,
cmd1 = 'ffmpeg -i ' + vid_path + ' -vf vfrdet -an -f null -'
os.system(cmd1)
cmd2 = 'ffprobe -v quiet -print_format json -show_streams ' + vid_path
os.system(cmd2)
cmd3 = 'ffmpeg -i ' + vid_path
os.system(cmd3)
all output 512
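A side note on those 512s: os.system only returns the shell's exit status (512 corresponds to an exit code of 2), and in a notebook the command's output goes to the server console rather than the cell. One likely cause of the non-zero status is that vid_path contains spaces and isn't quoted in the concatenated command string. A small sketch that captures the ffmpeg report directly (the argument-list form also sidesteps the quoting problem):

import subprocess

# Run the VFR check and capture ffmpeg's report (ffmpeg writes it to stderr).
# Passing the arguments as a list avoids shell quoting issues with spaces in vid_path.
result = subprocess.run(
    ['ffmpeg', '-i', vid_path, '-vf', 'vfrdet', '-an', '-f', 'null', '-'],
    capture_output=True, text=True)
print(result.returncode)
print(result.stderr)  # contains the "[Parsed_vfrdet_0 ...] VFR:..." line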
Running from a notebook,
vid = cv2.VideoCapture(vid_path)
fps = int(vid.get(cv2.CAP_PROP_FPS))
fps
this outputs 25
Running 'ffmpeg -i ' + vid_path from the terminal, I get:
Metadata:
major_brand : isom
minor_version : 512
compatible_brands: isomiso2avc1mp41
encoder : Lavf58.45.100
Duration: 01:28:38.07, start: 0.000000, bitrate: 2241 kb/s
Stream #0:0(und): Video: h264 (High) (avc1 / 0x31637661), yuv420p(tv, bt709), 1920x1080 [SAR 1:1 DAR 16:9], 2107 kb/s, 25 fps, 25 tbr, 12800 tbn, 50 tbc (default)
Metadata:
handler_name : ISO Media file produced by Google Inc. Created on: 02/23/2019.
Stream #0:1(und): Audio: aac (LC) (mp4a / 0x6134706D), 44100 Hz, stereo, fltp, 127 kb/s (default)
Metadata:
handler_name : ISO Media file produced by Google Inc. Created on: 02/23/2019.
The fact that you are receiving 25 FPS doesn't mean that all frames have equally distant timestamps, see https://superuser.com/questions/1487401/how-can-i-tell-if-a-video-has-a-variable-frame-rate. DALI uses frame timestamps to allow random access and generation of a sequence starting at any frame. If the frame timestamps are not equally spaced, the logic that decodes the video stream waits infinitely for a frame that may not be in the stream. We are aware of this limitation, but it is hard to tell when we will address it, as there are other tasks on our to-do list.
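If it helps, one way to inspect the timestamp spacing directly (rather than relying on the average FPS) is to dump the per-frame presentation timestamps with ffprobe and look at the deltas. A rough sketch, assuming ffprobe is on the PATH (it walks the whole file, so it takes a moment on a 90-minute video):

import subprocess

# Dump per-frame presentation timestamps for the video stream and check
# whether consecutive frames are (almost) equally spaced.
out = subprocess.run(
    ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
     '-show_entries', 'frame=pkt_pts_time', '-of', 'csv=p=0', vid_path],
    capture_output=True, text=True).stdout
pts = [float(t) for t in out.split() if t]
deltas = [b - a for a, b in zip(pts, pts[1:])]
print('min/max frame interval:', min(deltas), max(deltas))  # equal for a true CFR stream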
The link you sent suggests that a non-zero VFR value indicates a VFR stream.
I wrote $ ffmpeg -i vid_path -vf vfrdet -an -f null - in the terminal and got:
frame=132950 fps=1022 q=-0.0 Lsize=N/A time=01:28:38.00 bitrate=N/A speed=40.9x
video:69591kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: unknown
[Parsed_vfrdet_0 @ 0x7fee4fa05240] VFR:0.000000 (0/132949)
Does this mean my issue isn't due to a VFR?
Update: I ran this with another video that failed with the VideoPipeline. Again, I see VFR:0.000000 at the bottom.
Output:
$ ffmpeg -i vid_path -vf vfrdet -an -f null -
ffmpeg version 4.3.1 Copyright (c) 2000-2020 the FFmpeg developers
built with Apple clang version 12.0.0 (clang-1200.0.32.27)
configuration: --prefix=/usr/local/Cellar/ffmpeg/4.3.1_4 --enable-shared --enable-pthreads --enable-version3 --enable-avresample --cc=clang --host-cflags= --host-ldflags= --enable-ffplay --enable-gnutls --enable-gpl --enable-libaom --enable-libbluray --enable-libdav1d --enable-libmp3lame --enable-libopus --enable-librav1e --enable-librubberband --enable-libsnappy --enable-libsrt --enable-libtesseract --enable-libtheora --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxml2 --enable-libxvid --enable-lzma --enable-libfontconfig --enable-libfreetype --enable-frei0r --enable-libass --enable-libopencore-amrnb --enable-libopencore-amrwb --enable-libopenjpeg --enable-librtmp --enable-libspeex --enable-libsoxr --enable-videotoolbox --disable-libjack --disable-indev=jack
libavutil 56. 51.100 / 56. 51.100
libavcodec 58. 91.100 / 58. 91.100
libavformat 58. 45.100 / 58. 45.100
libavdevice 58. 10.100 / 58. 10.100
libavfilter 7. 85.100 / 7. 85.100
libavresample 4. 0. 0 / 4. 0. 0
libswscale 5. 7.100 / 5. 7.100
libswresample 3. 7.100 / 3. 7.100
libpostproc 55. 7.100 / 55. 7.100
Input #0, mov,mp4,m4a,3gp,3g2,mj2, from '001 All England Open - QF - LIN Dan (CHN) vs LEE Chong Wei (MAS).mp4':
Metadata:
major_brand : isom
minor_version : 512
compatible_brands: isomiso2avc1mp41
encoder : Lavf58.45.100
Duration: 01:01:00.86, start: 0.000000, bitrate: 1721 kb/s
Stream #0:0(und): Video: h264 (High) (avc1 / 0x31637661), yuv420p(tv, bt709), 1920x1080 [SAR 1:1 DAR 16:9], 1589 kb/s, 25 fps, 25 tbr, 90k tbn, 50 tbc (default)
Metadata:
handler_name : VideoHandler
Stream #0:1(und): Audio: aac (LC) (mp4a / 0x6134706D), 44100 Hz, stereo, fltp, 125 kb/s (default)
Metadata:
handler_name : SoundHandler
Stream mapping:
Stream #0:0 -> #0:0 (h264 (native) -> wrapped_avframe (native))
Press [q] to stop, [?] for help
Output #0, null, to 'pipe:':
Metadata:
major_brand : isom
minor_version : 512
compatible_brands: isomiso2avc1mp41
encoder : Lavf58.45.100
Stream #0:0(und): Video: wrapped_avframe, yuv420p, 1920x1080 [SAR 1:1 DAR 16:9], q=2-31, 200 kb/s, 25 fps, 25 tbn, 25 tbc (default)
Metadata:
handler_name : VideoHandler
encoder : Lavc58.91.100 wrapped_avframe
frame=91520 fps=892 q=-0.0 Lsize=N/A time=01:01:00.80 bitrate=N/A speed=35.7x
video:47905kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: unknown
[Parsed_vfrdet_0 @ 0x7f8d5ed60140] VFR:0.000000 (0/91519)
This suggests that your video file does not have a variable frame rate, so the issue should be somewhere else.
Yes, it seems to just stop without throwing an error. I tried running the pipeline without the model, and the same thing happens. It doesn't even give me an error - my notebook just freezes and I'm forced to restart the kernel.
Here's the code - simplified to just using the DALI iterator:
Note: print_eta is a custom function just to check progress.
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import nvidia.dali.ops as ops
import nvidia.dali.types as types
batch_size = 1
sequence_length = 50
initial_prefetch_size = 50
class VideoPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data, shuffle):
        super(VideoPipeline, self).__init__(batch_size, num_threads, device_id, seed=0)
        self.input = ops.VideoReader(device='gpu', filenames=data, sequence_length=sequence_length,
                                     shard_id=0, num_shards=1,
                                     random_shuffle=shuffle, initial_fill=initial_prefetch_size)
        self.resize = ops.Resize(size=(720, 720), device='gpu')
        # Apply a grayscale matrix linear transformation to the innermost dimension of your data - the channels
        self.grayscale = ops.CoordTransform(M=[0.35, 0.5, 0.15]*3, dtype=types.UINT8, device='gpu')
        self.normalize = ops.Normalize(mean=0.5, stddev=0.5, device='gpu')

    def define_graph(self):
        frames = self.input(name='Reader')
        frames = self.resize(frames)
        frames = self.grayscale(frames)
        frames = self.normalize(frames)
        return frames
video_pipe = VideoPipeline(batch_size=batch_size, num_threads=4, device_id=0, data=vid_path, shuffle=False)
pipes = [video_pipe]
pipes[0].build()
dali_iterator = DALIGenericIterator(pipes, ['data'], pipes[0].epoch_size('Reader'))
def print_eta(start, end, total_process_time, cycle_num, total_num_cycles, frame_num, num_crop_frames, cycle_type):
    '''
    Prints estimated time left for an algorithm to complete
    Note: Meant to be used in a loop
    Params:
        start: Start time of algorithm
        end: End time of last cycle
        total_process_time: Total time elapsed
        cycle_num: Cycle number
        total_num_cycles: Total number of cycles
        frame_num: Frame number
        num_crop_frames: Number of frames to crop
        cycle_type: Type of cycle (Ex. batch, frame, sequence, ...)
    Returns:
        end: End of current cycle
        total_process_time: New total time elapsed
    '''
    cycle_process_time = time.time() - end
    total_process_time += cycle_process_time
    average_cycle_time = total_process_time / cycle_num
    num_cycles_left = total_num_cycles - cycle_num
    eta = average_cycle_time * num_cycles_left
    eta_h = int(eta // 60 // 60)
    eta_m = int((eta - eta_h*60*60) // 60)
    eta_s = int((eta - eta_m*60) % 60)
    end = time.time()
    time_h = int((end-start) // 60 // 60)
    time_m = int(((end-start) - time_h*60*60) // 60)
    time_s = round(((end-start) - time_m*60) % 60)

    if frame_num != 0 and total_num_cycles != 0:
        sys.stdout.write('\rProcessing {}: {:.2f}% {}/{} Frames Saved: {:.2f}% {}/{} Time Elapsed: {} h {} min {} s ETA: {} h {} m {} s'
                         .format(cycle_type,
                                 cycle_num*100/total_num_cycles, cycle_num, total_num_cycles,
                                 num_crop_frames*100/frame_num, num_crop_frames, frame_num,
                                 time_h, time_m, time_s,
                                 eta_h, eta_m, eta_s))
        sys.stdout.flush()

    return end, total_process_time
def get_crop_frames(dali_iterator):
    start = time.time()
    end = time.time()
    total_process_time = 0
    total_num_sequences = dali_iterator.size
    num_crop_frames = 0  # defined here so the excerpt runs; nothing is collected in this simplified version

    # Iterate over all batches
    for batch_num, data in enumerate(dali_iterator, 1):
        frame_num = sequence_length * batch_num
        end, total_process_time = print_eta(start, end, total_process_time, batch_num, total_num_sequences, frame_num, num_crop_frames, cycle_type='Sequence')
get_crop_frames(dali_iterator)
For me, it hangs when the output is:
Processing Sequence: 9.04% 227/2511 Frames Saved: 0.00% 0/11350 Time Elapsed: 0 h 0 min 36 s ETA: 0 h 5 m 59 s
Here's a Google Drive link to the exact video used - please let me know if you can access it or not: https://drive.google.com/file/d/1JcOoV5aZv6SWOKFclCcqaYF5ZleBFzZE/view?usp=sharing
Maybe the issue is due to a lack of a tensor conversion? (I have no basis for this - just something I noticed.)
For example, in PyTorch the transformations would look something like:
transforms.Compose([transforms.Resize((720, 720)),
                    transforms.Grayscale(3),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)])
In my pipeline I'm running (as you can see above):
self.resize = ops.Resize(size=(720,720), device='gpu')
self.grayscale = ops.CoordTransform(M=[0.35, 0.5, 0.15]*3, dtype=types.UINT8, device='gpu')
self.normalize = ops.Normalize(mean=0.5, stddev=0.5, device='gpu')
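One thing worth noting about that comparison (it wouldn't explain a hang, only a different pixel range): torchvision's ToTensor() rescales uint8 pixels to [0, 1] before Normalize(0.5, 0.5), while the DALI Normalize above operates on the raw 0-255 values. A tiny sketch of the arithmetic:

import numpy as np

pixel = np.float32(200.0)  # an example uint8 intensity

# torchvision: ToTensor() scales to [0, 1], then Normalize(0.5, 0.5)
torch_style = (pixel / 255.0 - 0.5) / 0.5         # ~0.57, in [-1, 1]

# DALI pipeline above: Normalize(mean=0.5, stddev=0.5) on the raw value
dali_current = (pixel - 0.5) / 0.5                # 399.0, a very different range

# scaling the parameters to the 0-255 domain matches the torchvision result
dali_matched = (pixel - 0.5 * 255) / (0.5 * 255)  # ~0.57
print(torch_style, dali_current, dali_matched)

So if the model was trained on torchvision-style inputs, passing mean=0.5 * 255 and stddev=0.5 * 255 to ops.Normalize would bring the two pipelines in line; again, this is orthogonal to the freeze itself.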
I was able to reproduce this issue with the video sample you provided. We will investigate it and let you know. Thank you!
Thank you for looking into it! I look forward to your findings!
I'm wondering if this is a major bug or just a small one, and how long do you think it'll take to fix it?
We don't see any clear cause for this issue. We need to dive deeper. It would take at least a couple of days. We will get back to you as soon as we know more.
Hi, just a small reminder about this bug. Do you have any more information on the issue? Could I please get a status update?
We remember about this. So far we know that cuvidParseVideoData for the faulty frame doesn't call the respective pfnSequenceCallback, pfnDecodePicture, and pfnDisplayPicture callbacks. It is still unclear what the reason is (a Video SDK problem, or the logic of how DALI feeds it with data).
Update: I was on DALI 0.27 and just updated to 0.30 to see if the bug had been fixed incidentally. I tried the old video again and ran into the same issue. Then I tried a similar video (download here: https://drive.google.com/file/d/1oJFPkNgRg69vdoaGQYR-9orwP2y9RuEu/view?usp=sharing), and it now hangs in the model instead of in the iterator.
The DALI code is the same as before:
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import nvidia.dali.ops as ops
import nvidia.dali.types as types
batch_size = 1
sequence_length = 50
initial_prefetch_size = 50
class VideoPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data, shuffle):
        super(VideoPipeline, self).__init__(batch_size, num_threads, device_id, seed=0)
        self.input = ops.VideoReader(device='gpu', filenames=data, sequence_length=sequence_length,
                                     shard_id=0, num_shards=1,
                                     random_shuffle=shuffle, initial_fill=initial_prefetch_size)
        self.resize = ops.Resize(size=(720, 720), device='gpu')
        # Apply a grayscale matrix linear transformation to the innermost dimension of your data - the channels
        self.grayscale = ops.CoordTransform(M=[0.35, 0.5, 0.15]*3, dtype=types.UINT8, device='gpu')
        self.normalize = ops.Normalize(mean=0.5, stddev=0.5, device='gpu')

    def define_graph(self):
        frames = self.input(name='Reader')
        frames = self.resize(frames)
        frames = self.grayscale(frames)
        frames = self.normalize(frames)
        return frames
video_pipe = VideoPipeline(batch_size=batch_size, num_threads=4, device_id=0, data=vid_path, shuffle=False)
pipes = [video_pipe]
pipes[0].build()
dali_iterator = DALIGenericIterator(pipes, ['data'], pipes[0].epoch_size('Reader'))
num_sequences = dali_iterator.size
print('num_sequences:', num_sequences)
print('sequence_length:', sequence_length)
Model architecture:
class Net2(nn.Module):
    '''
    Neural network for binary classification
    Params:
        batch_size: -
        print_layers: -
    '''
    def __init__(self, batch_size, print_layers=False):
        super(Net2, self).__init__()

        stride = 2
        padding = 5
        image_size = 720
        conv_kernel_size = 4
        pool_kernel_size = 2
        num_kernels = [3, 40, 80, 120, 160, 200]
        output_size = 4

        fc1_size, fc2_size, fc3_size, fc4_size, fc5_size, fc6_size, fc7_size = \
            self._get_shapes(batch_size, image_size, output_size, conv_kernel_size, pool_kernel_size, padding, stride, num_kernels, print_layers)

        # Convolutional Layers
        self.conv1 = nn.Conv2d(num_kernels[0], num_kernels[1], conv_kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(num_kernels[1], num_kernels[2], conv_kernel_size, stride, padding)
        self.conv3 = nn.Conv2d(num_kernels[2], num_kernels[3], conv_kernel_size, stride, padding)
        self.conv4 = nn.Conv2d(num_kernels[3], num_kernels[4], conv_kernel_size, stride, padding)
        self.conv5 = nn.Conv2d(num_kernels[4], num_kernels[5], conv_kernel_size, stride, padding)
        self.max_pool = nn.MaxPool2d(pool_kernel_size, stride)

        # Fully Connected Layers
        self.dropout_prob = 0.3
        self.fc1 = nn.Linear(fc1_size, fc2_size)
        self.fc2 = nn.Linear(fc2_size, fc3_size)
        self.fc3 = nn.Linear(fc3_size, fc4_size)
        self.fc4 = nn.Linear(fc4_size, fc5_size)
        self.fc5 = nn.Linear(fc5_size, fc6_size)
        self.fc6 = nn.Linear(fc6_size, fc7_size)
        self.fc7 = nn.Linear(fc7_size, output_size)

    def _get_shapes(self, batch_size, image_size, output_size, conv_kernel_size, pool_kernel_size, padding, stride, num_kernels, print_layers):
        '''
        Calculates Convolutional and Pooling layer shapes
        Params:
            batch_size: -
            image_size: Size of original image
            output_size: Number of classes
            conv_kernel_size: -
            pool_kernel_size: -
            padding: -
            stride: -
            num_kernels: -
            print_layers (bool): True to print layer shapes, False otherwise
        Returns:
            fc1_size, fc2_size, fc3_size, fc4_size, fc5_size, fc6_size, fc7_size
        '''
        conv1_shape = self._conv_layer_shape(batch_size, image_size, num_kernels[1], conv_kernel_size, padding, stride)
        conv2_shape = self._conv_layer_shape(batch_size, conv1_shape[2], num_kernels[2], conv_kernel_size, padding, stride)
        pool2_shape = self._pool_layer_shape(batch_size, conv2_shape[2], num_kernels[2], pool_kernel_size, stride)
        conv3_shape = self._conv_layer_shape(batch_size, pool2_shape[2], num_kernels[3], conv_kernel_size, padding, stride)
        pool3_shape = self._pool_layer_shape(batch_size, conv3_shape[2], num_kernels[3], pool_kernel_size, stride)
        conv4_shape = self._conv_layer_shape(batch_size, pool3_shape[2], num_kernels[4], conv_kernel_size, padding, stride)
        pool4_shape = self._pool_layer_shape(batch_size, conv4_shape[2], num_kernels[4], pool_kernel_size, stride)
        conv5_shape = self._conv_layer_shape(batch_size, pool4_shape[2], num_kernels[5], conv_kernel_size, padding, stride)
        pool5_shape = self._pool_layer_shape(batch_size, conv5_shape[2], num_kernels[5], pool_kernel_size, stride)

        fc1_size = pool5_shape[1] * pool5_shape[2] * pool5_shape[3]
        fc2_size = fc1_size // 2
        fc3_size = fc2_size // 2
        fc4_size = fc3_size // 2
        fc5_size = fc4_size // 2
        fc6_size = fc5_size // 2
        fc7_size = fc6_size // 2

        if print_layers:
            print('{}:\t{}'.format("conv1_shape", conv1_shape))
            print('{}:\t{}'.format("conv2_shape", conv2_shape))
            print('{}:\t{}'.format("pool2_shape", pool2_shape))
            print('{}:\t{}'.format("conv3_shape", conv3_shape))
            print('{}:\t{}'.format("pool3_shape", pool3_shape))
            print('{}:\t{}'.format("conv4_shape", conv4_shape))
            print('{}:\t{}'.format("pool4_shape", pool4_shape))
            print('{}:\t{}'.format("conv5_shape", conv5_shape))
            print('{}:\t{}'.format("pool5_shape", pool5_shape))
            print('fc1_size:\t({}, {})'.format(batch_size, fc1_size))
            print('fc2_size:\t({}, {})'.format(batch_size, fc2_size))
            print('fc3_size:\t({}, {})'.format(batch_size, fc3_size))
            print('fc4_size:\t({}, {})'.format(batch_size, fc4_size))
            print('fc5_size:\t({}, {})'.format(batch_size, fc5_size))
            print('fc6_size:\t({}, {})'.format(batch_size, fc6_size))
            print('fc7_size:\t({}, {})'.format(batch_size, fc7_size))
            print('output_size:\t({}, {})'.format(batch_size, output_size))

        return fc1_size, fc2_size, fc3_size, fc4_size, fc5_size, fc6_size, fc7_size

    def _conv_layer_shape(self, batch_size, w_in, num_filters, kernel_size, padding, stride):
        '''
        Returns shape of a convolutional layer
        Parameters:
            batch_size: Batch size
            w_in: Width/Height of previous layer
            num_filters: Number of filters
            kernel_size: Filter/Kernel size
            padding: Padding
            stride: Stride
        Returns:
            shape: Shape of convolutional layer
        '''
        w_out = round((w_in - kernel_size + 2*padding)/stride + 1)
        shape = (batch_size, num_filters, w_out, w_out)
        return shape

    def _pool_layer_shape(self, batch_size, w_in, num_filters, kernel_size, stride):
        '''
        Returns shape of a pooling layer
        Params:
            batch_size: Batch size
            w_in: Width/Height of previous layer
            num_filters: Number of filters
            kernel_size: Filter/Kernel size
            stride: Stride
        Returns:
            shape: Shape of pooling layer
        '''
        w_out = round((w_in * (kernel_size-1) - 1) / stride)
        shape = (batch_size, num_filters, w_out, w_out)
        return shape

    def forward(self, x):
        '''
        Feed-forward
        Params:
            x: Batch of images
        '''
        # Convolutional Layers + ReLU + MaxPool
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.max_pool(x)
        x = F.relu(self.conv3(x))
        x = self.max_pool(x)
        x = F.relu(self.conv4(x))
        x = self.max_pool(x)
        x = F.relu(self.conv5(x))
        x = self.max_pool(x)

        # Flatten
        x = x.view(x.size(0), -1)

        # Fully-Connected Layers + ReLU + Dropout
        x = F.dropout(F.relu(self.fc1(x)), p=self.dropout_prob)
        x = F.relu(self.fc2(x))
        x = F.dropout(F.relu(self.fc3(x)), p=self.dropout_prob)
        x = F.relu(self.fc4(x))
        x = F.dropout(F.relu(self.fc5(x)), p=self.dropout_prob)
        x = F.relu(self.fc6(x))
        x = self.fc7(x)

        # Predictions
        x = F.log_softmax(x, dim=1)
        return x
Using the iterator and passing data to the model:
def use_iterator(net, device, dali_iterator):
    total_num_sequences = dali_iterator.size
    net.eval()

    # Iterate over all batches
    for batch_num, data in enumerate(dali_iterator, 1):
        data = torch.squeeze(data[0]['data'], 0)  # (1, 8, 720, 720, 1) --> (8, 720, 720, 1)
        print(data.shape)
        data = data.transpose(1, 3)               # (8, 720, 720, 1) --> (8, 1, 720, 720)
        print(data.shape)
        log_probs = net(data)
With the similar video, I get this output, and then it hangs:
torch.Size([50, 720, 720, 3])
torch.Size([50, 3, 720, 720])
This video doesn't have a VFR - see:
Input #0, mov,mp4,m4a,3gp,3g2,mj2, from '0011, 2018, All England Open, F, WATANABE-HIGASHINO (JPN) vs ZHENG-HUANG (CHN).mp4':
Metadata:
major_brand : isom
minor_version : 512
compatible_brands: isomiso2avc1mp41
encoder : Lavf58.45.100
Duration: 01:22:44.68, start: 0.000000, bitrate: 1927 kb/s
Stream #0:0(und): Video: h264 (High) (avc1 / 0x31637661), yuv420p(tv, bt709), 1920x1080 [SAR 1:1 DAR 16:9], 1795 kb/s, 25 fps, 25 tbr, 90k tbn, 50 tbc (default)
Metadata:
handler_name : VideoHandler
Stream #0:1(und): Audio: aac (LC) (mp4a / 0x6134706D), 44100 Hz, stereo, fltp, 125 kb/s (default)
Metadata:
handler_name : SoundHandler
Hi, it seems to be the same problem, which is not directly connected to the DALI version but to the Video SDK, which is part of the driver. I'm still waiting for the relevant team to check this out and get back to me with feedback.
After more thorough debugging and syncing with the Video SDK team, it looks like the issue is partially caused by the video itself and by the way DALI works. DALI VideoReader allows you to read video sequences with (almost) any step, stride, and length, and if for any reason a frame is missing from the video, or frames have malformed headers, DALI is not able to return the requested frames and hangs waiting for them indefinitely. We will think about how to address this kind of issue, for example by repeating the frame preceding the missing one, or doing something else that keeps the output consistent.
Since these videos have a CFR, I suppose they have malformed headers, which is causing the issue? Sounds good. Thank you for your work, and please let me know when there's an update on the issue!
Since these videos have a CFR, I suppose they have malformed headers which is causing the issue?
I think this is rather a property of the video format: not all frames have a header with all the information the decoder needs to decode directly from them. In this case, the decoder has to skip them until a frame with the necessary information is encountered.
Do you mean because it's in .mp4 format? If so, is there a format that would work better with DALI? (like a format that requires all frames to have a header with the necessary information)
MP4 is a container; the video itself is encoded as MPEG. Regarding the format, you may try out H264?
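If it's useful as a stopgap until the decoder issue is addressed, one possible workaround (a sketch only, untested against these exact files) would be to re-encode the video to a fresh constant-frame-rate H.264 stream so every frame gets clean, regular timestamps before handing it to VideoReader; 'fixed.mp4' below is just an example output name:

import subprocess

# Re-encode to a new constant-frame-rate H.264 file (example output name 'fixed.mp4').
subprocess.run(
    ['ffmpeg', '-i', vid_path,
     '-c:v', 'libx264', '-vsync', 'cfr', '-r', '25',  # force CFR at the reported 25 fps
     '-an',                                           # drop the audio track
     'fixed.mp4'],
    check=True)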
I'm trying to make a video pipeline with transformations for resize, grayscale, and normalize.
Following "Simple Video pipeline reading from multiple files" and "Using DALI in PyTorch" from the docs, here's what I have so far:
I'm getting an error:
How do I fix it? Please let me know if you need any other information.