scanner-research / scanner

Efficient video analysis at scale
https://scanner-research.github.io/
Apache License 2.0
615 stars 108 forks source link

Types on load are not associated with the correct types when there are multiple outputs #279

Open jhong93 opened 4 years ago

jhong93 commented 4 years ago
import sys
import cv2
import math
import json
import subprocess
from collections import namedtuple

import scannerpy as sp
from scannerpy import FrameType
from scannerpy.types import BboxList
import scannertools.face_detection
import scannertools.face_embedding
import scannertools.vis

from scannertools.face_embedding import FacenetEmbeddings

# Metadata record for one video: the path it was probed from, plus the frame
# count and duration values that get_video_info reads from ffprobe's output.
VideoInfo = namedtuple('VideoInfo', ['path', 'num_frames', 'duration'])

def get_video_info(video_path):
    """Probe a video file with ffprobe and return its basic metadata.

    Args:
        video_path: Path to the video file; passed to ffprobe unchanged.

    Returns:
        A VideoInfo holding ``video_path``, the ``nb_frames`` value of the
        first video stream as an int, and the container-level ``duration``
        value as a float.

    Raises:
        subprocess.CalledProcessError: If ffprobe exits with a non-zero
            status (for example, the file cannot be opened).
        ValueError: If ffprobe reports no video stream, or the video stream
            carries no frame count.
    """
    command = [
        'ffprobe',
        '-loglevel', 'quiet',
        '-print_format', 'json',
        '-show_format',
        '-show_streams',
        # Limit the stream report to the first *video* stream so that an
        # audio or subtitle stream listed first is never mistaken for it.
        '-select_streams', 'v:0',
        video_path
    ]
    # Discard stderr rather than merging it into stdout: anything ffprobe
    # writes there would otherwise corrupt the JSON document parsed below.
    result = subprocess.run(command, stdout=subprocess.PIPE,
                            stderr=subprocess.DEVNULL, check=True)
    obj = json.loads(result.stdout.decode())

    streams = obj.get('streams') or []
    if not streams:
        raise ValueError(
            'ffprobe found no video stream in {!r}'.format(video_path))

    frame_count = streams[0].get('nb_frames')
    if frame_count is None:
        # Not every container records a frame count in its stream header.
        raise ValueError(
            'ffprobe reported no nb_frames for the video stream in '
            '{!r}'.format(video_path))

    num_frames = int(frame_count)
    duration = float(obj['format']['duration'])
    return VideoInfo(video_path, num_frames, duration)

def main(video_path):
    """Run face detection and face embedding on one video, then load both outputs.

    This is the reproduction script for the reported issue: a single
    ``cl.run`` call writes two named output streams, and loading the second
    one (the embeddings) is where the failure occurs.

    Args:
        video_path: Path to the input video file.
    """
    info = get_video_info(video_path)

    # Sample one frame out of every ceil(3 * fps) frames.
    frames_per_second = info.num_frames / info.duration
    stride = math.ceil(3 * frames_per_second)

    cl = sp.Client()
    source_video = sp.NamedVideoStream(cl, 'example', path=video_path)
    all_frames = cl.io.Input([source_video])
    sampled_frames = cl.streams.Stride(all_frames, [stride])

    # Both ops consume the same strided frames; the embedder additionally
    # takes the detector's bounding boxes.
    bboxes = cl.ops.MTCNNDetectFaces(frame=sampled_frames)
    face_vectors = cl.ops.EmbedFaces(frame=sampled_frames, bboxes=bboxes)

    bbox_stream = sp.NamedStream(cl, 'face_bboxes')
    embedding_stream = sp.NamedStream(cl, 'face_embeddings')

    # Two separate output ops, submitted together in a single run.
    sinks = [
        cl.io.Output(bboxes, [bbox_stream]),
        cl.io.Output(face_vectors, [embedding_stream]),
    ]
    perf_params = sp.PerfParams.estimate(pipeline_instances_per_node=1)
    cl.run(sinks, perf_params, cache_mode=sp.CacheMode.Overwrite)

    # The first output loads, and has one entry per sampled frame.
    loaded_bboxes = list(bbox_stream.load())
    expected_rows = math.ceil(info.num_frames / stride)
    assert len(loaded_bboxes) == expected_rows

    loaded_embeddings = list(embedding_stream.load())  # Crashes here
    print(len(loaded_embeddings))

if __name__ == "__main__":
    # The video path is the first command-line argument. Exit with a usage
    # message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit('Usage: {} VIDEO_PATH'.format(sys.argv[0]))
    main(sys.argv[1])