lhwcv / mlsd_pytorch

Pytorch implementation of "M-LSD: Towards Light-weight and Real-time Line Segment Detection"
Apache License 2.0
190 stars 37 forks

How do I use a trained model in demo_MLSD_flask? #28

Open Code-Dataset opened 1 year ago

Code-Dataset commented 1 year ago

Hello, how do I use a model that was trained and saved in workdir/models with demo_MLSD_flask.py? Running it gives the following error:

Traceback (most recent call last):
  File "demo_MLSD_flask.py", line 296, in <module>
    init_worker(args)
  File "demo_MLSD_flask.py", line 255, in init_worker
    model = model_graph(args)
  File "demo_MLSD_flask.py", line 86, in __init__
    self.model = self.load(args.model_dir, args.model_type)
  File "demo_MLSD_flask.py", line 105, in load
    torch_model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
  File "C:\Users\ai\AppData\Roaming\Python\Python38\site-packages\torch\nn\modules\module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MobileV2_MLSD_Tiny:
    Unexpected key(s) in state_dict: "block17.weight", "block17.bias".
    size mismatch for backbone.features.0.0.weight: copying a param with shape torch.Size([32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 4, 3, 3]).
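
The error reports two separate problems: the checkpoint contains `block17.*` weights that the demo's model does not define, and the first backbone convolution takes 4 input channels in the demo's model but 3 in the checkpoint. A minimal sketch for listing such mismatches before calling `load_state_dict`, assuming both state dicts have first been turned into plain `{name: shape}` dictionaries (for example with `{k: tuple(v.shape) for k, v in sd.items()}`). The helper `diff_state_shapes` is hypothetical, and the `block17` shapes in the toy data are derived from `BilinearConvTranspose2d(16, 2, 1)` in the mlsd_tiny.py posted later in this thread:

```python
def diff_state_shapes(model_shapes, ckpt_shapes):
    """Compare two {parameter_name: shape_tuple} maps.

    Returns (missing, unexpected, mismatched):
      missing    - keys the model defines but the checkpoint lacks
      unexpected - keys the checkpoint has but the model does not define
      mismatched - {key: (checkpoint_shape, model_shape)} for shared keys whose shapes differ
    """
    missing = sorted(set(model_shapes) - set(ckpt_shapes))
    unexpected = sorted(set(ckpt_shapes) - set(model_shapes))
    mismatched = {
        k: (tuple(ckpt_shapes[k]), tuple(model_shapes[k]))
        for k in sorted(set(model_shapes) & set(ckpt_shapes))
        if tuple(ckpt_shapes[k]) != tuple(model_shapes[k])
    }
    return missing, unexpected, mismatched


# Toy shapes covering only the keys named in the traceback
demo_model = {"backbone.features.0.0.weight": (32, 4, 3, 3)}
checkpoint = {
    "backbone.features.0.0.weight": (32, 3, 3, 3),
    "block17.weight": (16, 16, 3, 3),
    "block17.bias": (16,),
}
print(diff_state_shapes(demo_model, checkpoint))
```

Running the same comparison on the real model and checkpoint shows every key and shape that disagrees, rather than only the first entries PyTorch happens to print.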

syvince commented 10 months ago

@Code-Dataset Did you solve it?

BALADA-CRAM commented 6 months ago

@Code-Dataset Did you solve it?

I have the same problem as you. Did you solve it?

syvince commented 6 months ago

I solved it; part of the code has to be changed.

syvince commented 6 months ago

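I solved it. In the end I didn't run it with Flask but with Gradio; the principle is the same, though. Take a look at mlsd_tiny.py, which is a file I rewrote from the original source code.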
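
The rewritten mlsd_tiny.py fixes the mismatch by building the training-time architecture: its first layer is `ConvBNReLU(3, ...)` and `MobileV2_MLSD_Tiny(with_deconv=True)` creates `block17`. A hedged sketch of reading those two settings off a checkpoint's `{name: shape}` map instead of guessing them. The function `infer_tiny_config` is hypothetical, and it relies only on the key names that appear in this thread, not on any official API:

```python
def infer_tiny_config(ckpt_shapes):
    """Read MobileV2_MLSD_Tiny settings off a {parameter_name: shape_tuple} map.

    Assumptions, based on the key names in this thread rather than any official API:
      - the first backbone conv is stored as 'backbone.features.0.0.weight' with
        shape (out_channels, in_channels, kH, kW);
      - 'block17.*' keys exist only if the model was built with with_deconv=True.
    """
    first_conv = ckpt_shapes["backbone.features.0.0.weight"]
    return {
        "in_channels": first_conv[1],
        "with_deconv": any(k.startswith("block17.") for k in ckpt_shapes),
    }


# First-conv shape from the traceback; block17 shapes follow from BilinearConvTranspose2d(16, 2, 1)
checkpoint = {
    "backbone.features.0.0.weight": (32, 3, 3, 3),
    "block17.weight": (16, 16, 3, 3),
    "block17.bias": (16,),
}
print(infer_tiny_config(checkpoint))  # {'in_channels': 3, 'with_deconv': True}
```

A result of 3 input channels with `with_deconv` True corresponds to `MobileV2_MLSD_Tiny(with_deconv=True)` as defined in mlsd_tiny.py, not to the 4-input-channel model that the Flask demo builds, which is the mismatch reported in the traceback.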

BALADA-CRAM commented 6 months ago

Bro, could you send it to my email? It doesn't display here, so I can't see the files. 1292031079@qq.com

ljl0311 commented 5 months ago

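@syvince Bro, I've been stuck on this too. How did you get it working? I can't see the file here either; could you send me a copy as well? Thanks a lot. My email is: 2727210959@qq.com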

ganqii commented 1 month ago

I ran into this problem too. Could you send me a copy? Thanks! Email: 1564388852@qq.com

syvince commented 4 weeks ago

mlsd_tiny.py

import os
import sys
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F

class BlockTypeA(nn.Module):
    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
        super(BlockTypeA, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c2, out_c2, kernel_size=1),
            nn.BatchNorm2d(out_c2),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c1, out_c1, kernel_size=1),
            nn.BatchNorm2d(out_c1),
            nn.ReLU(inplace=True)
        )
        self.upscale = upscale

    def forward(self, a, b):
        b = self.conv1(b)
        a = self.conv2(a)
        if self.upscale:
            b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
        return torch.cat((a, b), dim=1)

class BlockTypeB(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeB, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv1(x) + x
        x = self.conv2(x)
        return x

class BlockTypeC(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeC, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        self.channel_pad = out_planes - in_planes
        self.stride = stride
        # padding = (kernel_size - 1) // 2

        # TFLite uses slightly different padding than PyTorch
        # if stride == 2:
        #     padding = 0
        # else:
        padding = (kernel_size - 1) // 2

        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )
        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        # # TFLite uses  different padding
        # if self.stride == 2:
        #     x = F.pad(x, (0, 1, 0, 1), "constant", 0)
        #     #print(x.shape)

        for module in self:
            if not isinstance(module, nn.MaxPool2d):
                x = module(x)
        return x

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

class MobileNetV2(nn.Module):
    def __init__(self, pretrained=True):
        """
        Truncated MobileNetV2 backbone: only the first four inverted-residual stages
        are built, and forward() returns the feature maps at indices fpn_selected.
        Args:
            pretrained (bool): kept for compatibility; loading the ImageNet weights
                is commented out below, so it currently has no effect
        """
        super(MobileNetV2, self).__init__()

        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        width_mult = 1.0
        round_nearest = 8

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            # [6, 96, 3, 1],
            # [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        self.features = nn.Sequential(*features)

        self.fpn_selected = [1, 3, 6, 10]
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

        # if pretrained:
        #    self._load_pretrained_model()

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        fpn_features = []
        for i, f in enumerate(self.features):
            if i > self.fpn_selected[-1]:
                break
            x = f(x)
            if i in self.fpn_selected:
                fpn_features.append(x)

        c1, c2, c3, c4 = fpn_features
        return c1, c2, c3, c4

    def forward(self, x):
        return self._forward_impl(x)

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

class MobileV2_MLSD_Tiny(nn.Module):
    def __init__(self, with_deconv):
        super(MobileV2_MLSD_Tiny, self).__init__()

        self.backbone = MobileNetV2(pretrained=True)

        self.block12 = BlockTypeA(in_c1=32, in_c2=64,
                                  out_c1=64, out_c2=64)
        self.block13 = BlockTypeB(128, 64)

        self.block14 = BlockTypeA(in_c1=24, in_c2=64,
                                  out_c1=32, out_c2=32)
        self.block15 = BlockTypeB(64, 64)

        self.block16 = BlockTypeC(64, 16)

        self.with_deconv = with_deconv

        if self.with_deconv:
            self.block17 = BilinearConvTranspose2d(16, 2, 1)
            self.block17.reset_parameters()

    def forward(self, x):
        c1, c2, c3, c4 = self.backbone(x)

        x = self.block12(c3, c4)
        x = self.block13(x)
        x = self.block14(c2, x)
        x = self.block15(x)
        x = self.block16(x)
        # x = x[:, 7:, :, :]
        # print(x.shape)
        if self.with_deconv:
            x = self.block17(x)
        else:
            x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)

        return x

class BilinearConvTranspose2d(nn.ConvTranspose2d):
    """A conv transpose initialized to bilinear interpolation."""

    def __init__(self, channels, stride, groups=1):
        """Set up the layer.

        Parameters
        ----------
        channels: int
            The number of input and output channels

        stride: int or tuple
            The amount of upsampling to do

        groups: int
            Set to 1 for a standard convolution. Set equal to channels to
            make sure there is no cross-talk between channels.
        """
        if isinstance(stride, int):
            stride = (stride, stride)

        assert groups in (1, channels), "Must use no grouping, " + \
                                        "or one group per channel"

        kernel_size = (2 * stride[0] - 1, 2 * stride[1] - 1)
        padding = (stride[0] - 1, stride[1] - 1)
        super().__init__(
            channels, channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=padding,
            groups=groups)

    def reset_parameters(self):
        """Reset the weight and bias."""
        nn.init.constant_(self.bias, 0)
        nn.init.constant_(self.weight, 0)
        bilinear_kernel = self.bilinear_kernel(self.stride)
        for i in range(self.in_channels):
            if self.groups == 1:
                j = i
            else:
                j = 0
            self.weight.data[i, j] = bilinear_kernel

    @staticmethod
    def bilinear_kernel(stride):
        """Generate a bilinear upsampling kernel."""
        num_dims = len(stride)

        shape = (1,) * num_dims
        bilinear_kernel = torch.ones(*shape)

        # The bilinear kernel is separable in its spatial dimensions
        # Build up the kernel channel by channel
        for channel in range(num_dims):
            channel_stride = stride[channel]
            kernel_size = 2 * channel_stride - 1
            # e.g. with stride = 4
            # delta = [-3, -2, -1, 0, 1, 2, 3]
            # channel_filter = [0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25]
            delta = torch.arange(1 - channel_stride, channel_stride)
            channel_filter = (1 - torch.abs(delta / channel_stride))
            # Apply the channel filter to the current channel
            shape = [1] * num_dims
            shape[channel] = kernel_size
            bilinear_kernel = bilinear_kernel * channel_filter.view(shape)
        return bilinear_kernel
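
To see why web.py below divides by `input_size / 2` when rescaling, one can trace spatial sizes through this architecture with the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1). A pure-Python sketch, assuming a square input; it hard-codes the strides from `inverted_residual_setting`, `fpn_selected = [1, 3, 6, 10]`, the two 2x upsamplings in `block12`/`block14`, and `block17`. The helper names are hypothetical:

```python
def conv_out(size, kernel, stride, padding):
    # Output size of nn.Conv2d along one spatial axis (dilation=1)
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride, padding, output_padding):
    # Output size of nn.ConvTranspose2d along one spatial axis (dilation=1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace_tiny(input_size):
    """Trace the spatial size of a square input through MobileV2_MLSD_Tiny(with_deconv=True)."""
    # Stride of features[0..10]: the stem conv, then the rows (t, c, n, s) =
    # [1,16,1,1], [6,24,2,2], [6,32,3,2], [6,64,4,2]; only the first block of a row uses s.
    # The stride sits in a 3x3 conv with padding 1; the 1x1 convs keep the size.
    strides = [2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1]
    sizes, size = [], input_size
    for s in strides:
        size = conv_out(size, kernel=3, stride=s, padding=1)
        sizes.append(size)
    c1, c2, c3, c4 = (sizes[i] for i in (1, 3, 6, 10))  # fpn_selected = [1, 3, 6, 10]

    # block12 upsamples c4 by 2 and concatenates it with c3; block14 does the same with c2
    if 2 * c4 != c3 or 2 * c3 != c2:
        raise ValueError(f"input size {input_size} gives feature maps that cannot be concatenated")
    # block15/block16 keep the size; block17 = BilinearConvTranspose2d(16, 2, 1):
    # kernel 3, stride 2, padding 1, output_padding 1
    out = deconv_out(c2, kernel=3, stride=2, padding=1, output_padding=1)
    return {"c1": c1, "c2": c2, "c3": c3, "c4": c4, "output": out}


print(trace_tiny(512))
```

For a 512 input the trace gives a 256 x 256 output map, half the input size, which is the factor web.py uses when mapping line coordinates back to the original image.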
syvince commented 4 weeks ago

web.py

import cv2
import os

import numpy as np
import torch
from torch.nn import functional as F
import argparse

from mlsd_tiny import MobileV2_MLSD_Tiny
from albumentations import Normalize
import gradio as gr

from PIL import Image

def detect_weld(img, top_k, min_len, score_thresh):
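    # Gradio passes a PIL image; convert to 3-channel RGB so img.shape unpacks as (h, w, 3)
    img = np.array(img.convert("RGB"))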

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str,
                        default="/home/vsislab/python_work/weld_work/mlsd_pytorch/workdir/models/mobilev2_mlsd_tiny_512_bsize24/best.pth")
    # parser.add_argument()
    parser.add_argument("--input_size", type=int, help="image input size", default=512)
    # parser.add_argument("--sap_thresh", type=float, help="sAP thresh", default=10.0)
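    # argparse applies `type` only to string defaults, so cast the slider value here: torch.topk needs an int k
    parser.add_argument("--top_k", type=int, help="top k lines", default=int(top_k))  # 500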
    parser.add_argument("--min_len", type=float, help="min len of line", default=min_len)  # 0
    parser.add_argument("--score_thresh", type=float, help="line score thresh", default=score_thresh)  # 0.2
    opt = parser.parse_args()

    # img = cv2.imread()
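    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the training-time architecture (3-channel input, with_deconv=True) so that the
    # checkpoint's block17.* keys and first-conv shape match; .to(device) also works without a GPU
    model = MobileV2_MLSD_Tiny(with_deconv=True).to(device).eval()
    model.load_state_dict(torch.load(opt.model_path, map_location=device), strict=True)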

    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    line_info = mlsd_model(opt, model, img)
    lines = line_info["lines"]
    for line in lines:
        x0, y0, x1, y1 = map(int, line)  # convert coordinates to integers
        cv2.line(img, (x0, y0), (x1, y1), (0, 255, 0), 4)  # draw the segment in green with a line thickness of 4
    print(line_info)
    # img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return img, line_info

def mlsd_model(opt, model, img):
    # print(args)
    # print(args.model_path)

    h, w, _ = img.shape
    img = cv2.resize(img, (opt.input_size, opt.input_size))
    test_aug = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    img = test_aug(image=img)['image']

    img = img.transpose(2, 0, 1)
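    # Put the input on the same device as the model instead of assuming CUDA
    img = torch.from_numpy(img).unsqueeze(0).float().to(next(model.parameters()).device)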

    with torch.no_grad():
        batch_outputs = model(img)

    tp_mask = batch_outputs[:, 7:, :, :]
    center_ptss, pred_lines, scores = deccode_lines(tp_mask, opt.score_thresh, opt.min_len, opt.top_k, 3)
    pred_lines = pred_lines.detach().cpu().numpy()
    scores = scores.detach().cpu().numpy()
    pred_lines_list = []
    scores_list = []
    for line, score in zip(pred_lines, scores):
        x0, y0, x1, y1 = line

        x0 = w * x0 / (opt.input_size / 2)
        x1 = w * x1 / (opt.input_size / 2)

        y0 = h * y0 / (opt.input_size / 2)
        y1 = h * y1 / (opt.input_size / 2)

        pred_lines_list.append([x0, y0, x1, y1])
        scores_list.append(score)

    return {
        'width': w,
        'height': h,
        'lines': pred_lines_list,
        'scores': scores_list
    }

def deccode_lines(tpMap, score_thresh, len_thresh, topk_n, ksize=3):
    '''
    tpMap: (1, C, H, W) tensor, batch size 1
    center logits: tpMap[0, 0, :, :]
    displacement:  tpMap[0, 1:5, :, :]
    '''
    b, c, h, w = tpMap.shape
    assert b == 1, 'only support bsize==1'
    displacement = tpMap[:, 1:5, :, :]
    center = tpMap[:, 0, :, :]
    heat = torch.sigmoid(center)
    hmax = F.max_pool2d(heat, (ksize, ksize), stride=1, padding=(ksize - 1) // 2)
    keep = (hmax == heat).float()
    heat = heat * keep
    heat = heat.reshape(-1, )

    heat = torch.where(heat < score_thresh, torch.zeros_like(heat), heat)

    scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
    valid_inx = torch.where(scores > score_thresh)
    scores = scores[valid_inx]
    indices = indices[valid_inx]

    yy = torch.floor_divide(indices, w).unsqueeze(-1)
    xx = torch.fmod(indices, w).unsqueeze(-1)
    center_ptss = torch.cat((xx, yy), dim=-1)

    start_point = center_ptss + displacement[0, :2, yy, xx].reshape(2, -1).permute(1, 0)
    end_point = center_ptss + displacement[0, 2:, yy, xx].reshape(2, -1).permute(1, 0)

    lines = torch.cat((start_point, end_point), dim=-1)

    all_lens = (end_point - start_point) ** 2
    all_lens = all_lens.sum(dim=-1)
    all_lens = torch.sqrt(all_lens)
    valid_inx = torch.where(all_lens > len_thresh)

    center_ptss = center_ptss[valid_inx]
    lines = lines[valid_inx]
    scores = scores[valid_inx]

    return center_ptss, lines, scores

def main():
    top_k = gr.Slider(0, 1000, 500, label="top_k")
    min_len = gr.Slider(0, 10, 0, label="min_len")
    score_thresh = gr.Slider(0, 1, 0.2, label="score_thresh")
    line_info = gr.Textbox(label="lines info")
    weld_demo = gr.Interface(detect_weld, [gr.Image(type="pil"), top_k, min_len, score_thresh],
                             ["image", line_info],
                             examples=[
                                 ["example/1.jpg"],
                                 ["example/2.jpg"],
                                 ["example/3.jpg"],
                                 ["example/4.jpg"]
                             ],
                             ).launch(server_name="0.0.0.0", server_port=8860)  # 0.0.0.0 listens on all network interfaces

if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", type=str,
    #                     default="/home/vsislab/python_work/weld_work/mlsd_pytorch/workdir/models/mobilev2_mlsd_tiny_512_bsize24/best.pth")
    # # parser.add_argument()
    # parser.add_argument("--input_size", type=int, help="image input size", default=512)
    # parser.add_argument("--sap_thresh", type=float, help="sAP thresh", default=10.0)
    # parser.add_argument("--top_k", type=float, help="top k lines", default=1000)
    # parser.add_argument("--min_len", type=float, help="min len of line", default=2.0)
    # parser.add_argument("--score_thresh", type=float, help="line score thresh", default=0.1)
    # opt = parser.parse_args()
    main()
    # main(args)
syvince commented 4 weeks ago

@ganqii @ljl0311