PeterL1n / RobustVideoMatting

Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!
https://peterl1n.github.io/RobustVideoMatting/
GNU General Public License v3.0
8.61k stars 1.14k forks

C++ sample code available? #20

Open ewayboy opened 3 years ago

ewayboy commented 3 years ago

Does anyone have C++ code to run the demo?

PeterL1n commented 3 years ago

I don't have C++ sample code, but I will keep this issue open for others to answer.

BrightenWu commented 3 years ago

Here is the Python inference code ported to C++:

    auto device = torch::Device("cuda");
    auto precision = torch::kFloat16;
    auto downsampleRatio = 0.4;
    c10::optional<torch::Tensor> tensorRec0;
    c10::optional<torch::Tensor> tensorRec1;
    c10::optional<torch::Tensor> tensorRec2;
    c10::optional<torch::Tensor> tensorRec3;

    auto model = torch::jit::load("rvm_mobilenetv3_fp16.torchscript");
    // torch::jit::freeze raised an error on this model, so it is left commented out.
    //model = torch::jit::freeze(model);
    model.to(device);

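    // imgSrc: 8-bit RGB image data, such as a QImage. from_blob does not copy the buffer and,
    // with these sizes, assumes tightly packed rows (no per-row padding).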
    auto tensorSrc = torch::from_blob(imgSrc.bits(), { imgSrc.height(),imgSrc.width(),3 }, torch::kByte);
    tensorSrc = tensorSrc.to(device);
    tensorSrc = tensorSrc.permute({ 2,0,1 }).contiguous();
    tensorSrc = tensorSrc.to(precision).div(255);
    tensorSrc.unsqueeze_(0);

    //! Inference
    auto outputs = model.forward({ tensorSrc,tensorRec0,tensorRec1,tensorRec2,tensorRec3,downsampleRatio }).toList();

    const auto &fgr = outputs.get(0).toTensor();
    const auto &pha = outputs.get(1).toTensor();
    tensorRec0 = outputs.get(2).toTensor();
    tensorRec1 = outputs.get(3).toTensor();
    tensorRec2 = outputs.get(4).toTensor();
    tensorRec3 = outputs.get(5).toTensor();

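    // Green target background color (RGB, scaled to [0, 1])
    auto tensorTargetBgr = torch::tensor({ 120.f / 255, 255.f / 255, 155.f / 255 }).toType(precision).to(device).view({ 1, 3, 1, 1 });
    // Composite the foreground over the target background using the alpha matte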
    auto res_tensor = pha * fgr + (1 - pha) * tensorTargetBgr;

    res_tensor = res_tensor.mul(255).permute({ 0,2,3,1 })[0].to(torch::kU8).contiguous().cpu();
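
The compositing line above is ordinary alpha blending applied to every pixel. As a minimal sketch of the same arithmetic on plain floats, without LibTorch: the names `Rgb`, `compositePixel`, and `toByte` are hypothetical helpers introduced only for illustration, and truncating to 8 bits (rather than rounding) is an assumption of this sketch.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>

// One RGB pixel with channels scaled to [0, 1].
using Rgb = std::array<float, 3>;

// Blend one pixel: out = pha * fgr + (1 - pha) * bgr, channel by channel.
// pha is the alpha matte value for this pixel (1 = foreground, 0 = background).
Rgb compositePixel(const Rgb& fgr, float pha, const Rgb& bgr) {
    Rgb out{};
    for (std::size_t c = 0; c < 3; ++c) {
        out[c] = pha * fgr[c] + (1.0f - pha) * bgr[c];
    }
    return out;
}

// Convert a [0, 1] value to an 8-bit channel.
// Clamps to [0, 1] first, then truncates toward zero (an assumption of this sketch).
std::uint8_t toByte(float v) {
    const float clamped = std::min(1.0f, std::max(0.0f, v));
    return static_cast<std::uint8_t>(clamped * 255.0f);
}
```

For example, `compositePixel({1, 0, 0}, 0.5f, {0, 1, 0})` gives a half-red, half-green pixel, and passing each channel through `toByte` turns it back into displayable 8-bit data.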

DefTruth commented 3 years ago

@ewayboy @PeterL1n

C++ Demos for RobustVideoMatting:

rvm2021

semchan commented 3 years ago

Running this code continuously on a video makes memory usage grow until the program crashes. Is there a better way to handle it? Thanks a lot.

ewayboy commented 3 years ago

Do you need to release the previous res_tensor data? It is copied back to host memory.

semchan commented 3 years ago

But I found that it is not caused by "res_tensor"; it may be caused by "tensorRec0, tensorRec2...". When tensorRec0, 1, 2, 3 are kept as global variables, memory usage explodes:

    tensorRec0 = outputs.get(2).toTensor();
    tensorRec1 = outputs.get(3).toTensor();
    tensorRec2 = outputs.get(4).toTensor();
    tensorRec3 = outputs.get(5).toTensor();

HZNUJeffreyRen commented 2 years ago

@semchan I have the same problem. Did you solve it?

ked19 commented 1 year ago

I detached the tensors and that seems to have solved the problem:

    tensorRec0 = outputs.get(2).toTensor().detach();
    tensorRec1 = outputs.get(3).toTensor().detach();
    tensorRec2 = outputs.get(4).toTensor().detach();
    tensorRec3 = outputs.get(5).toTensor().detach();
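
One plausible reading of why detaching helps, consistent with the fix above but not confirmed in this thread: if each recurrent state keeps a reference to the computation that produced it, feeding it back in on every frame builds a chain that reaches all the way to the first frame, so nothing is ever freed. The sketch below is a toy model of that mechanism in plain C++, not LibTorch code; `State`, `step`, `detach`, `retainedStates`, and `runFrames` are hypothetical names introduced only for illustration.

```cpp
#include <cstddef>
#include <memory>

// Toy stand-in for a tensor that remembers which state it was computed from.
struct State {
    std::shared_ptr<State> history;  // plays the role of a link to earlier computation
};

using StatePtr = std::shared_ptr<State>;

// Hypothetical model step: the new recurrent state records its input as history.
StatePtr step(const StatePtr& previous) {
    auto next = std::make_shared<State>();
    next->history = previous;
    return next;
}

// Toy analogue of detaching: a fresh state with no link to earlier frames.
StatePtr detach(const StatePtr&) {
    return std::make_shared<State>();
}

// Number of states kept alive through the history chain, including `state` itself.
std::size_t retainedStates(const StatePtr& state) {
    std::size_t count = 0;
    for (const State* s = state.get(); s != nullptr; s = s->history.get()) {
        ++count;
    }
    return count;
}

// Run `frames` steps, feeding each output state back in as the next input.
std::size_t runFrames(std::size_t frames, bool detachEachFrame) {
    StatePtr rec;  // empty on the first frame, like the unset optional tensors
    for (std::size_t i = 0; i < frames; ++i) {
        rec = step(rec);
        if (detachEachFrame) {
            rec = detach(rec);  // drop the link to all earlier frames
        }
    }
    return retainedStates(rec);
}
```

In this toy model, `runFrames(100, false)` keeps 100 states alive while `runFrames(100, true)` keeps only 1. In LibTorch itself, creating a `torch::NoGradGuard` in the scope that calls `model.forward` is a common way to avoid recording history in the first place; whether it fully resolves the growth reported here has not been tested in this thread.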