anibali / tvl

Torch video loading library with support for GPU decoding
Apache License 2.0

NVDec backend high GPU utilisation #6

Open Multihuntr opened 5 years ago

Multihuntr commented 5 years ago

It seems that the NVDec backend has much higher GPU utilisation than the NVVL backend.

I ran this benchmarking script, which @anibali gave me as a reference, to begin investigating why the NVDec backend seemed slower for me.

import torch
import tvl
from tvl_backends.nvvl import NvvlBackendFactory
from tvl_backends.nvdec import NvdecBackendFactory
from tvl import transforms
import time

n_frames = 100
video_file = './video00c5c3.mp4'

def basic():
  # Decode with NVDEC at full resolution, no resizing.
  tvl.set_backend_factory('cuda', NvdecBackendFactory())
  vl = tvl.VideoLoader(video_file, 'cuda:0')
  list(vl.select_frames([0])) # init
  a = time.time()
  for i in range(n_frames):
    f = list(vl.select_frames([i, i+3, i+6, i+9, i+12], 4))
  b = time.time()
  print('tvl-basic', n_frames/(b-a))

def downsized():
  # Decode with NVDEC, then resize the decoded frames using tvl.transforms.
  tvl.set_backend_factory('cuda', NvdecBackendFactory())
  vl = tvl.VideoLoader(video_file, 'cuda:0')
  list(vl.select_frames([0])) # init
  a = time.time()
  for i in range(n_frames):
    f = list(vl.select_frames([i, i+3, i+6, i+9, i+12], 4))
    f = torch.stack(f, 0).unsqueeze(0)
    f = transforms.resize(f, (1080, 1920), 'nearest')
  b = time.time()
  print('tvl-downsampled', n_frames/(b-a))

def outdim():
  # Decode with NVDEC, asking the backend itself to do the resize.
  tvl.set_backend_factory('cuda', NvdecBackendFactory())
  vl = tvl.VideoLoader(video_file, 'cuda:0',
                       backend_opts={'resize': (1080, 1920)})
  list(vl.select_frames([0])) # init
  a = time.time()
  for i in range(n_frames):
    f = list(vl.select_frames([i, i+3, i+6, i+9, i+12], 4))
  b = time.time()
  print('tvl-outdim', n_frames/(b-a))

def nvvl():
  # Decode with NVVL at full resolution, no resizing.
  tvl.set_backend_factory('cuda', NvvlBackendFactory())
  vl = tvl.VideoLoader(video_file, 'cuda:0')
  list(vl.select_frames([0])) # init
  a = time.time()
  for i in range(n_frames):
    f = list(vl.select_frames([i, i+3, i+6, i+9, i+12], 4))
  b = time.time()
  print('tvl-nvvl', n_frames/(b-a))

def nvvlscale():
  # Decode with NVVL, with a 0.5x scale applied by the backend.
  tvl.set_backend_factory('cuda', NvvlBackendFactory())
  vl = tvl.VideoLoader(video_file, 'cuda:0',
                       backend_opts={'scale': 0.5})
  list(vl.select_frames([0])) # init
  a = time.time()
  for i in range(n_frames):
    f = list(vl.select_frames([i, i+3, i+6, i+9, i+12], 4))
  b = time.time()
  print('tvl-nvvlscale', n_frames/(b-a))

nvvlscale()
nvvl()
basic()
downsized()
outdim()

By watching nvidia-smi while the benchmark ran and noting which section it was up to, I observed the following GPU utilisation:

tvl-nvvlscale:       3% GPU utilisation
tvl-nvvl:        10-11% GPU utilisation 
tvl-basic:       36-49% GPU utilisation
tvl-downsampled: 41-53% GPU utilisation
tvl-outdim:      12-16% GPU utilisation

It seems that doing the downsampling inside the CUDA kernel reduces GPU utilisation dramatically, but even so, the NVVL backend shows several times lower GPU utilisation than the NVDec backend.

This excess utilisation gets in the way of the main purpose of this library: feeding data to a pytorch model on the GPU.
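
For reference, the utilisation numbers above were just eyeballed from nvidia-smi. Something like the following could poll them programmatically while the benchmark runs (this uses pynvml, which is not part of tvl; purely illustrative):

import time
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0

# Sample GPU utilisation once a second while the benchmark runs in another process.
for _ in range(60):
  util = pynvml.nvmlDeviceGetUtilizationRates(handle)
  print(f'GPU utilisation: {util.gpu}%')
  time.sleep(1)

pynvml.nvmlShutdown()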

anibali commented 5 years ago

I can confirm that the performance hit is caused by the NV12 -> RGB conversion, which is currently written using Torch tensor operations. So we can either write our own CUDA kernel like NVVL does (note: their implementation is a bit buggy and causes a green line down the side of the image), or try to optimise the Torch conversion code.
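
For anyone following along: NV12 stores a full-resolution Y plane followed by an interleaved half-resolution UV plane, and the conversion to RGB is a small linear transform per pixel. A rough sketch in plain Torch ops (illustrative only, not the exact code in tvl, and assuming BT.601 limited-range coefficients):

import torch

def nv12_to_rgb(nv12, height, width):
  # nv12: flat uint8 tensor of length height * width * 3 // 2
  y = nv12[:height * width].reshape(height, width).float()
  uv = nv12[height * width:].reshape(height // 2, width // 2, 2).float()
  # Upsample the chroma planes to full resolution (nearest neighbour).
  u = uv[..., 0].repeat_interleave(2, dim=0).repeat_interleave(2, dim=1)
  v = uv[..., 1].repeat_interleave(2, dim=0).repeat_interleave(2, dim=1)
  y = 1.164 * (y - 16.0)
  u = u - 128.0
  v = v - 128.0
  r = y + 1.596 * v
  g = y - 0.392 * u - 0.813 * v
  b = y + 2.017 * u
  return torch.stack([r, g, b]).clamp(0, 255) / 255.0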

anibali commented 5 years ago

I had a crack at optimising the NV12 to RGB conversion in https://github.com/anibali/tvl/commit/5ce5989deb155773329c3d0fc3d24d11dc9fea2a. In terms of wall-clock running time, this was a considerable improvement over the previous version. I think that GPU utilisation is a bit of a flawed metric for this kind of thing, since all it represents is "percent of time over the past second during which one or more kernels was executing on the GPU". So, as far as I can tell, one active kernel doing minimal work continuously is enough to peg the utilisation at 100%. It is interesting that the backends differ on that metric, but I don't think it's solid enough to base decisions on or to optimise for.
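
On measuring wall-clock time for GPU work: because kernel launches are asynchronous, a naive timer mostly measures launch overhead, so it's worth synchronising around the timed region. A generic helper along these lines (not part of tvl, just an illustration):

import time
import torch

def timed(fn, iters=20):
  # Make sure previously queued kernels are finished before starting the clock.
  torch.cuda.synchronize()
  start = time.time()
  for _ in range(iters):
    fn()
  # Wait for the work launched by fn() to actually complete.
  torch.cuda.synchronize()
  return (time.time() - start) / iters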

Multihuntr commented 5 years ago

I am reasonably certain that this is having a very large impact on training time (approximately 2x training time). I come to this conclusion from two specific tests that I've done, plus the fact that all my attempts to disprove it have failed.

Attempts to disprove

Code

I started using cProfile to find which parts of the code are taking a long time. It pointed me towards the model code, and I initially thought it was just because pytorch 1.0 is slower, but as I dug into it, it pointed me to some really innocuous-looking code.

It first pointed me to places in the code where we were using cuda tensors to index into cuda tensors. I had thought that would make things faster, but of course the tensor metadata is still on the CPU, so it's better to keep everything used for indexing on the CPU if you can; I've now fixed this on my branch.
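
For illustration, the pattern I mean is roughly this (simplified; the tensors here are made up):

import torch

feats = torch.randn(5000, 128, device='cuda')

# What I was doing: building the index tensor on the GPU as well.
idx_gpu = torch.arange(0, 5000, 7, device='cuda')
selected = feats[idx_gpu]

# What I switched to: keep the index on the CPU. The result is the same,
# but in my case this avoided a noticeable amount of overhead.
idx_cpu = torch.arange(0, 5000, 7)
selected = feats[idx_cpu]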

Then it pointed me towards multi-part expressions, like:

unsummed_kl = p * (torch.log(p + eps) - torch.log(q + eps))

I took a shot at optimizing those expressions (mostly caching constants on the GPU), but when I broke it down, the simple multiplication was taking a comparatively enormous amount of time.
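
To give an idea of the kind of tweak I mean (made-up tensors, same maths):

import torch

eps = 1e-8
p = torch.rand(2, 5, 64, device='cuda')
q = torch.rand(2, 5, 64, device='cuda')

# Original form: eps is a Python float that gets promoted on every call.
unsummed_kl = p * (torch.log(p + eps) - torch.log(q + eps))

# Tweaked form: keep the constant on the GPU and fold the two logs into one,
# so fewer kernels are launched per expression.
eps_gpu = torch.tensor(eps, device='cuda')
unsummed_kl = p * torch.log((p + eps_gpu) / (q + eps_gpu))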

I kept referring to the cProfile timing results on our main branch, and these parts were taking a fraction of a second over a 5 minute period, but while using tvl they were taking about 25s in a 2 minute period.

Environment

The environment is a bit awkward because we were using CUDA 9.1 + python 3.5 + pytorch 0.4.0 with nvvl, whereas the Dockerfile for tvl uses CUDA 10 + python 3.6 (from anaconda) + pytorch 1.0.0, and those use different pytorch binaries. I played around with a few different Frankenstein-monster Dockerfiles that combined various CUDA, python and pytorch versions. I definitely tried:

I might have tried:

For all of these, I observed the same ~0.5x speed. There did not seem to be any significant difference between them.

Tests

Torch's Profiler

After banging my head against cProfile for a while, I discovered that pytorch has a profiler, and found a comment from Soumith saying that cProfile output is pretty useless because CUDA is asynchronous (although I did run all the tests with CUDA_LAUNCH_BLOCKING=1, which is supposed to make it synchronous).
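
For reference, this is roughly how I invoked it (a minimal sketch with a stand-in model, not our actual training code):

import torch
from torch.autograd import profiler

model = torch.nn.Linear(1024, 1024).cuda()   # stand-in for the real model
x = torch.randn(64, 1024, device='cuda')

# use_cuda=True makes the profiler record CUDA times alongside CPU times.
with profiler.profile(use_cuda=True) as prof:
  for _ in range(10):
    model(x).sum().backward()

print(prof.key_averages().table(sort_by='cuda_time_total'))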

Here are the results.

Looking at the total CUDA time, it looks like there's no significant difference between the two. This implies to me that the profiler is checking how long each step took when it had access to the GPU, but it didn't always have access to the GPU.

The obvious one that I should have started with

I made a mock data loader that always returns the same random example (the frames are generated with torch.rand, the targets are copied from a real example). With this change, the slowdown disappears, and everything runs as expected.

For completeness:

import torch
import numpy as np

fake_targets_data = np.array(...) # snipped
fake_targets_mask = np.array(...) # snipped

fake_example = {
  'targets': np.ma.MaskedArray(fake_targets_data, fake_targets_mask),
  'frames': torch.rand(2, 5, 3, 1080, 1920).cuda(),
}

class FakeLoader:
  def __init__(self, *args, **kwargs):
    pass

  def __iter__(self):
    # Yield the same pre-built example forever.
    while True:
      yield fake_example
anibali commented 5 years ago

Looking at the total CUDA time, it looks like there's no significant difference between the two. This implies to me that the profiler is checking how long each step took when it had access to the GPU, but it didn't always have access to the GPU.

@Sibras is it possible that the NVDEC backend is preventing PyTorch from accessing the GPU by effectively "locking" it? Maybe due to fast and loose usage of cuCtxPushCurrent/cuCtxPopCurrent, for example?

Sibras commented 5 years ago

Yeah, pushing and popping contexts can definitely cause some issues and will create synchronisation stalls. Whether they are causing these particular issues is another question, but they are still a likely culprit.

anibali commented 5 years ago

If you could keep an eye out for this as you restructure the code, that would be great. I think that this currently poses a significant problem for Brandon and Ash.

bhack commented 5 years ago

I think you can also check how DALI is handling this :wink: https://github.com/NVIDIA/DALI