NVIDIA / VideoProcessingFramework

A set of Python bindings to C++ libraries that provide full hardware acceleration for video decoding and encoding, as well as GPU-accelerated color space and pixel format conversions.
Apache License 2.0

How to properly decode multiple frames and convert them to PyTorch tensors #551

Closed: darkAlert closed this issue 1 year ago

darkAlert commented 1 year ago

My config:

I'm using the following class to decode video frames and convert them to PyTorch tensors:

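import PyNvCodec as nvc
import PytorchNvCodec as pnvc
import torch
from torch.utils.data import Dataset

class VideoDataset(Dataset):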
    def __init__(self, video_path, gpu_id=0):
        self.video_path = video_path
        self.gpu_id = gpu_id
        self.decoder = None
        self.nv12_to_yuv = None
        self.yuv420_to_rgb = None
        self.rgb_to_pln = None
        self.cc_ctx = None
        self.width = None
        self.height = None

    def open(self):
        # Init HW decoder, color converters and color space conversion context:
        self.decoder = nvc.PyNvDecoder(self.video_path, self.gpu_id)
        self.width, self.height = self.decoder.Width(), self.decoder.Height()
        self.nv12_to_yuv = nvc.PySurfaceConverter(
            self.width, self.height, nvc.PixelFormat.NV12, nvc.PixelFormat.YUV420, self.gpu_id
        )
        self.yuv420_to_rgb = nvc.PySurfaceConverter(
            self.width, self.height, nvc.PixelFormat.YUV420, nvc.PixelFormat.RGB, self.gpu_id
        )
        self.rgb_to_pln = nvc.PySurfaceConverter(
            self.width, self.height, nvc.PixelFormat.RGB, nvc.PixelFormat.RGB_PLANAR, self.gpu_id
        )
        self.cc_ctx = nvc.ColorspaceConversionContext(
            nvc.ColorSpace.BT_601, nvc.ColorRange.MPEG
        )

        return self

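    def __len__(self):
        # Open the video lazily so len() also works before the first frame is decoded:
        if self.decoder is None:
            self.open()
        return self.decoder.Numframes()

    def __getitem__(self, idx):
        # Frames are decoded sequentially; idx is only used to detect the end.
        # IndexError ends the sequence iteration used by "for ... in dataset":
        if idx >= len(self):
            raise IndexError(idx)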
        # Decode 1 compressed video frame to CUDA memory:
        nv12_surface = self.decoder.DecodeSingleSurface()
        if nv12_surface.Empty():
            print("Can not decode frame")
            return None

        # Convert from NV12 to YUV420
        # This extra step is required because not all NV12 -> RGB conversions
        # implemented in NPP support all color spaces and ranges:
        yuv420 = self.nv12_to_yuv.Execute(nv12_surface, self.cc_ctx)
        if yuv420.Empty():
            print("Can not convert nv12 -> yuv420")
            return None

        # Convert from YUV420 to interleaved RGB:
        rgb24 = self.yuv420_to_rgb.Execute(yuv420, self.cc_ctx)
        if rgb24.Empty():
            print("Can not convert yuv420 -> rgb")
            return None

        # Convert from RGB to planar RGB:
        rgb24_planar = self.rgb_to_pln.Execute(rgb24, self.cc_ctx)
        if rgb24_planar.Empty():
            print("Can not convert rgb -> rgb planar")
            return None

        if rgb24_planar.Format() != nvc.PixelFormat.RGB_PLANAR:
            raise RuntimeError("Surface shall be of RGB_PLANAR pixel format")

        surf_plane = rgb24_planar.PlanePtr()
        img_tensor = pnvc.DptrToTensor(
            surf_plane.GpuMem(),
            surf_plane.Width(),
            surf_plane.Height(),
            surf_plane.Pitch(),
            surf_plane.ElemSize(),
        )
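        if img_tensor is None:
            raise RuntimeError("Can not export to tensor.")

        # The planar plane is 3 * H rows tall; view it as a (1, 3, H, W) batch of one frame:
        return img_tensor.reshape(1, 3, surf_plane.Height() // 3, surf_plane.Width())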
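The export above depends on how a planar RGB surface is laid out. Assuming 1-byte elements and the layout used in VPF's PyTorch sample (the R, G and B planes stacked vertically, so the plane is three times the frame height, with every row padded to `Pitch()` bytes), the plain-Python sketch below computes where a pixel lives in such a plane and strips the row padding. `planar_rgb_offset` and `unpitch` are hypothetical helpers for illustration, not VPF functions.

```python
def planar_rgb_offset(channel: int, row: int, col: int, frame_height: int, pitch: int) -> int:
    """Byte offset of pixel (row, col) in channel 0=R, 1=G or 2=B.

    Assumes 1-byte elements and three planes stacked vertically, so the
    plane is 3 * frame_height rows tall and each row occupies `pitch` bytes.
    """
    if channel not in (0, 1, 2):
        raise ValueError("channel must be 0 (R), 1 (G) or 2 (B)")
    return (channel * frame_height + row) * pitch + col


def unpitch(buf: bytes, width: int, height: int, pitch: int) -> bytes:
    """Drop per-row padding: keep the first `width` bytes of each of `height` rows."""
    if pitch < width:
        raise ValueError("pitch must be >= width")
    return b"".join(buf[r * pitch : r * pitch + width] for r in range(height))


# Usage: a 2x2 frame (so a 6-row plane) whose rows are padded to 4 bytes.
w, h, pitch = 2, 2, 4
plane = bytearray(3 * h * pitch)
plane[planar_rgb_offset(2, 1, 0, h, pitch)] = 255  # blue channel, row 1, col 0
packed = unpitch(bytes(plane), w, 3 * h, pitch)
print(len(packed))  # 12 == 3 channels * 2 rows * 2 cols
```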

Then I iterate over the dataset, collect two adjacent frames and merge them into a batch that is passed to the model:

batch = []
for idx, img_tensor in enumerate(dataset):
    batch.append(img_tensor)
    if len(batch) < 2:
        continue

    # Merge the two frames into one batch and normalize to [0, 1]:
    frames = torch.cat(batch, 0).float() / 255.0
    batch = []  # start collecting the next pair

    if torch.equal(frames[0], frames[1]):
        print('Equal', idx)

    preds = model(frames)
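The pairing in the loop is a generic grouping step and can be factored out. A minimal plain-Python sketch; `chunked` is a hypothetical helper, and dropping an incomplete final group is a design choice here that matches a loop which only runs the model on full pairs.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunked(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive groups of `size` items; an incomplete final group is dropped."""
    if size < 1:
        raise ValueError("size must be >= 1")
    group: List[T] = []
    for item in items:
        group.append(item)
        if len(group) == size:
            yield group
            group = []  # start a fresh list instead of mutating the one just yielded


print(list(chunked(range(5), 2)))  # [[0, 1], [2, 3]]
```

Note that grouping only collects references; it does not copy the frames, so if the producer reuses memory between calls, the grouped items still share it.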

To check for correctness I use torch.equal. On the GPU, torch.equal reports many pairs of identical frames, and which frames are affected changes randomly from run to run. When I save such images, they are indeed identical and also contain various artifacts, as if the GPU memory were being overwritten by something.

I tried using torch.clone, but it didn't help.

On the CPU, however, everything works fine. Moving the tensors from the GPU to the CPU before calling torch.cat, and back to the GPU before calling the model, also gives correct results.

Without batching (one frame per iteration), it also works fine.

I don't use multithreading.

darkAlert commented 1 year ago

I found the answer in this thread: https://github.com/NVIDIA/VideoProcessingFramework/issues/506#issuecomment-1647804639

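The fix is to clone the planar RGB surface in __getitem__ before exporting it with DptrToTensor, presumably so that each frame gets its own copy rather than memory that the next conversion overwrites:

        rgb24_planar = rgb24_planar.Clone()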
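Why cloning helps can be illustrated without a GPU. The sketch below assumes, consistently with the Clone() fix but without checking VPF's source, that the converter writes every result into one internal buffer and hands back a view of it. `ReusingConverter` is a hypothetical stand-in: a `memoryview` plays the role of the returned surface and `bytes(...)` plays the role of `Clone()`.

```python
class ReusingConverter:
    """Hypothetical converter that writes every result into one shared buffer."""

    def __init__(self, frame_size: int) -> None:
        self._buf = bytearray(frame_size)

    def execute(self, value: int) -> memoryview:
        # Overwrite the shared buffer with this "frame" and return a view of it, not a copy.
        for i in range(len(self._buf)):
            self._buf[i] = value
        return memoryview(self._buf)


conv = ReusingConverter(frame_size=4)

# Keeping the views: both entries alias the same buffer and show the last frame.
views = [conv.execute(1), conv.execute(2)]
print(views[0] == views[1])  # True

# Copying right after each call (playing the role of Clone()) keeps the frames distinct.
copies = [bytes(conv.execute(1)), bytes(conv.execute(2))]
print(copies[0] == copies[1])  # False
```

This is only an analogy for the symptom; the actual buffer management lives inside VPF and may differ.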