milesial / Pytorch-UNet

PyTorch implementation of the U-Net for image semantic segmentation with high quality images
GNU General Public License v3.0

How to do the C++ implementation #448

Open Li-Yidong opened 1 year ago

Li-Yidong commented 1 year ago

Thanks for sharing. I successfully trained the model on my own dataset, and now I want to run inference in C++. I wrote the code below:

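#include <torch/script.h>   // torch::jit::load
#include <torch/torch.h>    // torch::from_blob, torch::NoGradGuard
#include <opencv2/opencv.hpp>
#include <iostream>

UnetDetector::OutputVal UnetDetector::detect(cv::Mat src)
{
    UnetDetector::OutputVal result;

    // Consider loading the module once (e.g. in the constructor) instead of on every call.
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("D:\\work_file\\Libtorch_test\\Libtorch_test\\x64\\Release\\checkpoint_epoch5.pth");
    }
    catch (const c10::Error& e) {
        std::cerr << "Error loading the model: " << e.msg() << std::endl;
        return result;  // nothing to run without a model
    }
    module.eval();
    torch::NoGradGuard no_grad;  // inference only, no autograd bookkeeping

    cv::Mat img, img_rgb, img_float;
    // resize to half resolution
    cv::Size new_size(src.cols / 2, src.rows / 2);
    cv::resize(src, img, new_size);

    // cv::Mat (BGR, uint8) -> RGB float in [0, 1]
    cv::cvtColor(img, img_rgb, cv::COLOR_BGR2RGB);
    img_rgb.convertTo(img_float, CV_32FC3, 1.0 / 255.0);

    at::TensorOptions options(at::kFloat);
    options = options.device(torch::kCPU);

    // Interleaved HWC buffer -> NCHW tensor, the layout the UNet expects.
    // contiguous() copies, so the tensor no longer aliases img_float.
    at::Tensor tensor_image = torch::from_blob(img_float.data, { 1, img_float.rows, img_float.cols, 3 }, options)
        .permute({ 0, 3, 1, 2 })
        .contiguous();

    // Output is float logits of shape [1, n_classes, H, W], not uint8.
    at::Tensor output = module.forward({ tensor_image }).toTensor();

    // Per-pixel class index as a uint8 mask of shape [H, W].
    // Assumes n_classes > 1; for a single-class model, threshold torch::sigmoid(output) at 0.5 instead.
    at::Tensor mask = output.argmax(1).squeeze(0).to(at::kByte).contiguous();

    cv::Mat image(static_cast<int>(mask.size(0)), static_cast<int>(mask.size(1)), CV_8UC1, mask.data_ptr<uint8_t>());

    // Nearest-neighbour keeps class labels intact; resize allocates new memory for output_img.
    cv::Mat output_img;
    cv::resize(image, output_img, src.size(), 0, 0, cv::INTER_NEAREST);

    result.out_img = output_img;
    return result;
}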

and got the following error:

Error loading the model: PytorchStreamReader failed locating file constants.pkl: file not found

What should I do?
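The message points at the file rather than the C++ code: torch::jit::load reads only TorchScript archives, while checkpoint_epoch5.pth is written by the repository's train.py with torch.save on the model's state dict. With PyTorch 1.6 or newer both formats are zip files, but only a TorchScript archive contains a constants.pkl entry, which is exactly what the reader reports as missing. The usual remedy happens on the Python side: rebuild the UNet with the same n_channels, n_classes and bilinear settings, load the state dict (recent versions of train.py also store an extra mask_values entry that has to be removed first), call eval(), export with torch.jit.trace or torch.jit.script, save the result, and pass that file to torch::jit::load. As a quick diagnostic in C++, here is a minimal sketch that guesses whether a file is a TorchScript archive. The function names are hypothetical and the check is a byte-level heuristic, not a LibTorch API.

```cpp
#include <fstream>
#include <iterator>
#include <string>

// Hypothetical helper (heuristic, not part of LibTorch).
// A TorchScript archive is a zip file containing an entry named ".../constants.pkl".
// Zip entry names are stored uncompressed in the file headers, so a substring
// search finds them; a state-dict checkpoint from torch.save has no such entry.
bool looks_like_torchscript_archive(const std::string& bytes)
{
    // Zip files begin with the local file header signature "PK\x03\x04".
    const bool is_zip = bytes.size() >= 4 && bytes.compare(0, 4, "PK\x03\x04") == 0;
    return is_zip && bytes.find("constants.pkl") != std::string::npos;
}

// Hypothetical helper: reads the whole file and applies the check above.
// Returns false if the file cannot be opened.
bool file_looks_like_torchscript(const std::string& path)
{
    std::ifstream in(path, std::ios::binary);
    if (!in) {
        return false;
    }
    std::string bytes((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
    return looks_like_torchscript_archive(bytes);
}
```

Under these assumptions, calling file_looks_like_torchscript on the training checkpoint should return false, and on a file produced by torch.jit.save it should return true. A false result before calling torch::jit::load is a cheaper signal than catching c10::Error, though the heuristic could in principle be fooled by tensor data that happens to contain the same bytes.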