PeterL1n / RobustVideoMatting

Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!
https://peterl1n.github.io/RobustVideoMatting/
GNU General Public License v3.0

Request for Support - Real-time demo using Replicate GPU #220

Open · Agusteando opened this issue 1 year ago

Agusteando commented 1 year ago

Hello PeterL1n,

I am trying to use the "Robust Video Matting" model hosted at arielreplicate/robust_video_matting on Replicate, and I would like to know whether it is possible to pass a stream as the input_video parameter. I am interested in trying out a real-time demo using a Replicate GPU.

Is it possible to feed a video stream to the input_video parameter and run the tool in real time? If not, are there any workarounds or alternative methods to accomplish this? I would really appreciate your help!

Thanks in advance!

gdamms commented 1 year ago

I think you can get a good first version working with opencv-python (pip install opencv-python).

Once you've installed the package, you can import it with:

import cv2

You can then read streams such as your webcam, or video files, using VideoCapture.

# To read your webcam stream (-1 selects the first available device).
vc = cv2.VideoCapture(-1)

# To read a video file.
vc = cv2.VideoCapture("path/to/file")
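Since the question is about streaming input: VideoCapture also accepts network stream URLs (e.g. RTSP or HTTP), and it is worth failing fast when the source cannot be opened. A minimal sketch (the URL below is a placeholder, not a real endpoint):

# VideoCapture also accepts network stream URLs (placeholder address).
vc = cv2.VideoCapture("rtsp://example.com/stream")

# Fail fast if the source could not be opened.
if not vc.isOpened():
    raise RuntimeError("Could not open the video source.")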

Here is a simple code snippet:

import cv2
import torch
from torchvision import transforms

from model import MattingNetwork

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DOWNSAMPLING_RATIO = 1.0  # Fine for low-res webcams; the RVM docs suggest ~0.25 for HD input.

# Set up the model.
model = MattingNetwork("mobilenetv3").eval().to(DEVICE)
model.load_state_dict(torch.load("checkpoints/rvm_mobilenetv3.pth", map_location=DEVICE))

# Set up the initial state.
rec = [None] * 4  # Recurrent states, fed back into the model on each frame.
bgr = torch.tensor([0.0, 1.0, 0.0], device=DEVICE).view(3, 1, 1)  # Green composite background.

# Set up the video reader.
video_reader = cv2.VideoCapture(-1)

with torch.no_grad():

    # Start the live inference.
    running = True
    while running:

        # Get the next frame.
        ret, frame = video_reader.read()
        if not ret:
            running = False
            continue

        # OpenCV delivers BGR frames; convert to RGB for the model.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Infer the frame.
        frame = transforms.ToTensor()(frame)
        frame = frame.to(DEVICE, non_blocking=True).unsqueeze(0)
        fgr, pha, *rec = model(frame, *rec, DOWNSAMPLING_RATIO)

        # Compose the frame.
        com = fgr * pha + bgr * (1 - pha)
        frame = com.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]

        # Show the frame (converting back to BGR for OpenCV).
        cv2.imshow("Live RVM", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        # Exit when "q" or "ESC" is pressed.
        key = cv2.waitKey(1)
        if key == ord("q") or key == 27:
            running = False

# Release the capture device and close the preview window.
video_reader.release()
cv2.destroyAllWindows()
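If your camera delivers HD frames, inference gets much faster with a smaller downsample_ratio (the RVM docs recommend around 0.25 for 1080p and 0.125 for 4K). A rough heuristic you could drop in before the loop; pick_downsample_ratio is a hypothetical helper, not part of RVM:

def pick_downsample_ratio(height: int) -> float:
    # Rough heuristic following the RVM docs: full resolution for
    # small frames, heavier downsampling as the input grows.
    if height <= 512:
        return 1.0
    if height <= 1080:
        return 0.25
    return 0.125

frame_height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
DOWNSAMPLING_RATIO = pick_downsample_ratio(frame_height)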