PeterL1n / RobustVideoMatting

Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!
https://peterl1n.github.io/RobustVideoMatting/
GNU General Public License v3.0

Python to C++. #26

Status: Open. wangsnowsea2020 opened this issue 3 years ago.

wangsnowsea2020 commented 3 years ago

Python to C++.

    auto device = torch::Device("cuda");
    auto precision = torch::kFloat16;
    auto downsampleRatio = 0.4;
    //! Recurrent states: empty (passed as None) on the first frame, then fed back after each frame.
    c10::optional<torch::Tensor> tensorRec0;
    c10::optional<torch::Tensor> tensorRec1;
    c10::optional<torch::Tensor> tensorRec2;
    c10::optional<torch::Tensor> tensorRec3;

    auto model = torch::jit::load("rvm_mobilenetv3_fp16.torchscript");
    //! Freezing raised an error, so it is left disabled.
    //model = torch::jit::freeze(model);
    model.to(device);

    //! imgSrc: 8-bit RGB image data, e.g. a QImage in QImage::Format_RGB888 (3 bytes per pixel, matching the { H, W, 3 } sizes below).
    auto tensorSrc = torch::from_blob(imgSrc.bits(), { imgSrc.height(),imgSrc.width(),3 }, torch::kByte);
    tensorSrc = tensorSrc.to(device);
    tensorSrc = tensorSrc.permute({ 2,0,1 }).contiguous();
    tensorSrc = tensorSrc.to(precision).div(255);
    tensorSrc.unsqueeze_(0);

    //! Inference
    auto outputs = model.forward({ tensorSrc,tensorRec0,tensorRec1,tensorRec2,tensorRec3,downsampleRatio }).toList();

    const auto &fgr = outputs.get(0).toTensor();
    const auto &pha = outputs.get(1).toTensor();
    tensorRec0 = outputs.get(2).toTensor();
    tensorRec1 = outputs.get(3).toTensor();
    tensorRec2 = outputs.get(4).toTensor();
    tensorRec3 = outputs.get(5).toTensor();

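    //! Green target background in RGB order ("bgr" here means background, not BGR channel order).
    auto tensorTargetBgr = torch::tensor({ 120.f / 255, 255.f / 255, 155.f / 255 }).toType(precision).to(device).view({ 1, 3, 1, 1 });
    //! Composite: alpha-blend the foreground over the background.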
    auto res_tensor = pha * fgr + (1 - pha) * tensorTargetBgr;

    res_tensor = res_tensor.mul(255).permute({ 0,2,3,1 })[0].to(torch::kU8).contiguous().cpu();
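For reference, the preprocessing in the snippet above (from_blob with sizes { H, W, 3 }, then permute({ 2, 0, 1 }) and div(255)) does the following per pixel. This is a minimal plain C++ sketch, assuming 8-bit interleaved RGB input; hwcToChwFloat and bytesPerRow are hypothetical names used only for illustration. The explicit row stride is there because from_blob called with sizes only assumes tightly packed rows of width * 3 bytes, while QImage aligns each scanline to a 32-bit boundary (see QImage::bytesPerLine()). For a Format_RGB888 image whose width * 3 is not a multiple of 4, the rows are therefore padded; LibTorch's from_blob also has an overload that accepts explicit strides for that case.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of RVM or LibTorch).
// Converts an interleaved 8-bit RGB image (H x W x 3, rows bytesPerRow apart)
// into a planar float buffer in CHW order with values in [0, 1].
// Mirrors: from_blob({H, W, 3}) -> permute({2, 0, 1}) -> div(255).
std::vector<float> hwcToChwFloat(const std::uint8_t* data, int height, int width,
                                 std::size_t bytesPerRow) {
    assert(data != nullptr && height > 0 && width > 0);
    // Each row must hold at least width * 3 bytes; any extra bytes are padding.
    assert(bytesPerRow >= static_cast<std::size_t>(width) * 3);

    const std::size_t plane = static_cast<std::size_t>(height) * width;
    std::vector<float> chw(3 * plane);
    for (int y = 0; y < height; ++y) {
        // Step by the real row stride so padding bytes are skipped.
        const std::uint8_t* row = data + y * bytesPerRow;
        for (int x = 0; x < width; ++x) {
            for (int c = 0; c < 3; ++c) {
                // Channel c of pixel (y, x) goes to plane c.
                chw[c * plane + static_cast<std::size_t>(y) * width + x] =
                    row[x * 3 + c] / 255.0f;
            }
        }
    }
    return chw;
}
```

For a QImage, bytesPerRow would be imgSrc.bytesPerLine(); for a cv::Mat it would be frame.step[0].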

Originally posted by @BrightenWu in https://github.com/PeterL1n/RobustVideoMatting/issues/20#issuecomment-922211633

The code above works for a single image. With the code below, the first frame is processed correctly, but the program crashes on the second frame. Could anyone tell me what the problem is?

    while (vCap.read(frame)) {
        cv::cvtColor(frame, srcframe, cv::COLOR_BGR2RGB);

        // HWC uint8 RGB frame -> 1 x 3 x H x W tensor in [0, 1] on the GPU.
        auto src = torch::from_blob(srcframe.data, { srcframe.rows, srcframe.cols, 3 }, torch::kByte);
        src = src.to(device);
        src = src.permute({ 2, 0, 1 }).contiguous();
        src = src.to(precision).div(255);
        src.unsqueeze_(0);

        //auto outputs = model.forward({ src, tRec0, tRec1, tRec2, tRec3, downsampleRatio }).toTuple()->elements();
        auto outputs = model.forward({ src, tRec0, tRec1, tRec2, tRec3, downsampleRatio }).toList();

        const auto& fgr = outputs.get(0).toTensor();
        const auto& pha = outputs.get(1).toTensor();

        // Feed the recurrent states back in for the next frame.
        tRec0 = outputs.get(2).toTensor();
        tRec1 = outputs.get(3).toTensor();
        tRec2 = outputs.get(4).toTensor();
        tRec3 = outputs.get(5).toTensor();

        // newbgr and torchTensortoCVMat are defined elsewhere in the poster's code (not shown).
        auto com = pha * fgr + newbgr * (1 - pha);

        cv::Mat resultImg = torchTensortoCVMat(com);

        cv::cvtColor(resultImg, resultImg, cv::COLOR_RGB2BGR);

        cv::imshow("demo", resultImg);
        if (cv::waitKey(1) >= 0)
            break;
    }
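The compositing step in the loop (com = pha * fgr + newbgr * (1 - pha)), followed by conversion to an image buffer, can be sketched per pixel in plain C++. newbgr and torchTensortoCVMat are not shown in the post, so this sketch assumes newbgr is a solid RGB color with components in [0, 1] (like the green 120/255, 255/255, 155/255 in the single-image code) and that the result should be an interleaved H x W x 3 uint8 buffer, the layout a CV_8UC3 cv::Mat uses. compositeToHwcU8 is a hypothetical name. Unlike the original, it clamps values before the uint8 cast, and it needs C++17 for std::clamp.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper for illustration only.
// Composites a CHW float foreground over a solid RGB background using a
// single-channel alpha matte, then converts to interleaved HWC uint8.
// Mirrors: com = pha * fgr + bg * (1 - pha);
//          com.mul(255).permute({0, 2, 3, 1})[0].to(torch::kU8)
std::vector<std::uint8_t> compositeToHwcU8(const std::vector<float>& fgr,   // 3 x H x W
                                           const std::vector<float>& pha,   // 1 x H x W
                                           const std::array<float, 3>& bg,  // RGB in [0, 1]
                                           int height, int width) {
    const std::size_t plane = static_cast<std::size_t>(height) * width;
    assert(fgr.size() == 3 * plane && pha.size() == plane);

    std::vector<std::uint8_t> hwc(3 * plane);
    for (std::size_t i = 0; i < plane; ++i) {
        const float a = std::clamp(pha[i], 0.0f, 1.0f);
        for (std::size_t c = 0; c < 3; ++c) {
            // Alpha blend: foreground where a == 1, background where a == 0.
            const float v = a * fgr[c * plane + i] + (1.0f - a) * bg[c];
            // Clamp before the cast: converting an out-of-range float to uint8 is undefined.
            hwc[i * 3 + c] = static_cast<std::uint8_t>(std::clamp(v, 0.0f, 1.0f) * 255.0f);
        }
    }
    return hwc;
}
```

The channel order stays RGB, which is why the loop still converts with cv::COLOR_RGB2BGR before cv::imshow.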
semchan commented 3 years ago


Could it be running out of memory? Check that.

BrightenWu commented 2 years ago

Try calling com.cpu() at the end.

HZNUJeffreyRen commented 2 years ago

I have the same problem. How did you solve it?

o2co2 commented 2 years ago

    tRec0 = outputs.get(2).toTensor();
    tRec1 = outputs.get(3).toTensor();
    tRec2 = outputs.get(4).toTensor();
    tRec3 = outputs.get(5).toTensor();

These lines make GPU memory blow up. Why?
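One possible explanation for memory growing only when the states are fed back, offered as an assumption rather than a diagnosis confirmed in this thread: if autograd is enabled during inference, each returned state tensor keeps a reference to the graph that produced it, and that graph in turn references the previous frame's state, so the retained history grows with every frame. The toy model below is plain C++, not LibTorch internals; ToyTensor, step, retainedStates and runFrames are hypothetical names. It uses shared_ptr to show the effect: with tracking on, every past state stays reachable from the latest one; with tracking off, only one state is kept.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>

// Toy stand-in for a tensor that may record where it came from (like grad_fn).
// Hypothetical type for illustration; this is not how LibTorch is implemented.
struct ToyTensor {
    std::shared_ptr<const ToyTensor> history;  // previous state kept alive, or null
};

// One "frame" of a recurrent model: the new state is computed from the old one.
// With trackHistory, the new state keeps the old one alive, the way an autograd
// graph keeps the inputs it may need for a backward pass.
std::shared_ptr<const ToyTensor> step(const std::shared_ptr<const ToyTensor>& prev,
                                      bool trackHistory) {
    auto next = std::make_shared<ToyTensor>();
    if (trackHistory) next->history = prev;
    return next;
}

// Counts how many states are still reachable (i.e. kept in memory) from s.
std::size_t retainedStates(std::shared_ptr<const ToyTensor> s) {
    std::size_t n = 0;
    for (; s; s = s->history) ++n;
    return n;
}

// Runs `frames` iterations, feeding each state back in as the next input.
std::size_t runFrames(std::size_t frames, bool trackHistory) {
    std::shared_ptr<const ToyTensor> rec;  // empty on the first frame, like tRec0..tRec3
    for (std::size_t i = 0; i < frames; ++i) rec = step(rec, trackHistory);
    return retainedStates(rec);
}
```

Here runFrames(100, true) keeps 100 states alive while runFrames(100, false) keeps one. In LibTorch, the usual way to switch gradient tracking off for inference is to declare a torch::NoGradGuard (for example `torch::NoGradGuard no_grad;`) in the scope that runs the loop; whether that alone fixes the crash reported here is not confirmed in this thread.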

DanishFaraaz commented 2 years ago

I am processing a video and I get the same error. The code keeps looping but only the first frame is displayed in the window. Did anyone find a fix or have a better solution?