MarcoForte / FBA_Matting

Official repository for the paper F, B, Alpha Matting
MIT License
464 stars 95 forks source link

Convert to onnx #54

Open ghost opened 1 year ago

ghost commented 1 year ago

Hello, how can I convert the pretrained model to ONNX?

99991 commented 1 week ago

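Here is an example that exports FBA-Net to ONNX and runs the exported model with ONNX Runtime. Before running, download FBA.pth and place it in the current working directory. I have only tested this with a single image, so there is no guarantee that it works with images of other sizes: the graph is traced at the size of the example input, so other sizes will probably require re-exporting.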

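It would also be possible to replace the call to OpenCV's distance transform (`cv2.distanceTransform`) with pure PyTorch, removing the OpenCV dependency from the preprocessing (OpenCV would still be used for image loading and resizing in the test). That is a fair bit of work, so I have not done it for now.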

import numpy as np
import os, cv2, math, urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyGroupNorm(nn.Module):
    """Drop-in replacement for ``nn.GroupNorm`` built only from basic tensor ops.

    Useful when the target ONNX runtime has no native GroupNorm operator. The
    learnable parameters are called ``weight`` and ``bias`` (shape
    ``(num_channels,)``), like ``nn.GroupNorm``, so state dicts are
    interchangeable.

    Args:
        num_groups: number of groups the channels are split into.
        num_channels: number of input channels; must be divisible by ``num_groups``.
        eps: added to the variance for numerical stability.

    Raises:
        ValueError: if ``num_channels`` is not divisible by ``num_groups``.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        # Same default initialisation as nn.GroupNorm: identity affine transform.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        """Normalise a 4-D ``(N, C, H, W)`` tensor per sample and per channel group."""
        n, c, h, w = x.shape
        # reshape (not view) so that non-contiguous inputs are accepted too.
        x = x.reshape(n, self.num_groups, -1)
        mean = x.mean(dim=-1, keepdim=True)
        # Biased variance (correction=0), matching nn.GroupNorm.
        var = x.var(dim=-1, keepdim=True, correction=0)
        x = (x - mean) / torch.sqrt(var + self.eps)
        x = x.reshape(n, c, -1) * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return x.view(n, c, h, w)

class MyAdaptiveAvgPool2d(nn.Module):
    """Adaptive 2-D average pooling written with explicit slicing.

    Bin ``i`` along an axis of input length ``L`` and output length ``S`` covers
    ``[floor(i * L / S), ceil((i + 1) * L / S))``, the same bins as
    ``F.adaptive_avg_pool2d``. Only slicing and ``mean`` are used, so the
    traced ONNX graph avoids the adaptive pooling operator. The bin boundaries
    are computed in Python from the traced input shape, so the exported graph
    is specific to that input size.

    Args:
        output_size: ``int`` for a square output, or an ``(out_h, out_w)`` pair.
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def forward(self, batch):
        """Pool a 4-D ``(N, C, H, W)`` tensor to ``(N, C, out_h, out_w)``.

        The result keeps ``batch``'s dtype and device.
        """
        size = self.output_size
        out_h, out_w = (size, size) if isinstance(size, int) else size
        n, c, h, w = batch.shape
        rows = []
        for y in range(out_h):
            y0 = math.floor(y * h / out_h)
            y1 = math.ceil((y + 1) * h / out_h)
            cells = []
            for x in range(out_w):
                x0 = math.floor(x * w / out_w)
                x1 = math.ceil((x + 1) * w / out_w)
                cells.append(batch[:, :, y0:y1, x0:x1].mean(dim=(2, 3)))
            # (N, C, out_w)
            rows.append(torch.stack(cells, dim=-1))
        # (N, C, out_h, out_w); stacking (instead of writing into a float32
        # zeros buffer) preserves the input dtype.
        return torch.stack(rows, dim=-2)

def norm(dim):
    """Return the normalisation layer used throughout the network.

    This is ``nn.GroupNorm`` with 32 groups over ``dim`` channels. If the
    target ONNX runtime lacks a GroupNorm operator, return
    ``MyGroupNorm(32, dim)`` here instead. It has the same ``weight`` and
    ``bias`` parameters, so pretrained state dicts still load.
    """
    groups = 32
    return nn.GroupNorm(num_groups=groups, num_channels=dim)

class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose stored weights are weight-standardised in place.

    Each output filter is standardised to zero mean and unit (biased) variance
    via ``F.batch_norm``, with its default ``eps``. This happens in two places:
    on every forward pass in training mode, and on every ``train()``/``eval()``
    call. After switching to eval mode the standardisation is baked into
    ``self.weight``, so ``forward`` is a plain convolution and exports to ONNX
    as one.
    """

    def normalize_weight(self):
        """Overwrite ``self.weight`` with its per-output-channel standardised version."""
        # No autograd graph is needed: the result replaces .data directly.
        with torch.no_grad():
            weight = F.batch_norm(
                self.weight.view(1, self.out_channels, -1),
                None,
                None,
                training=True,
                momentum=0.0,
            ).reshape_as(self.weight)
        self.weight.data = weight

    def forward(self, x):
        if self.training:
            self.normalize_weight()

        return super().forward(x)

    def train(self, mode: bool = True):
        """Set training mode, re-standardise the weights and return ``self``.

        ``nn.Module.train``/``eval`` return the module so calls can be chained.
        """
        super().train(mode=mode)
        self.normalize_weight()
        return self

def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1, bias=False):
    """Build a 3x3 weight-standardised ``Conv2d`` (bias off by default)."""
    options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "bias": bias,
    }
    return Conv2d(in_planes, out_planes, **options)

def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Build a pointwise (1x1) weight-standardised ``Conv2d`` with no padding."""
    return Conv2d(in_planes, out_planes, kernel_size=1, bias=bias, stride=stride)

def dt(a):
    """Return the Euclidean distance from each pixel to the nearest zero pixel.

    ``a`` is a single-channel float mask, nominally in [0, 1]. It is scaled to
    8 bit, and pixels that end up as 0 are the targets of the distance
    transform.

    ``a`` is clipped to [0, 1] first. Resampled masks (e.g. Lanczos-resized
    trimaps) overshoot slightly outside that range. Casting a negative or >255
    float to ``uint8`` wraps or is undefined, which can turn a "zero" pixel
    into a non-zero one (or vice versa).
    """
    mask = (np.clip(a, 0.0, 1.0) * 255).astype(np.uint8)
    # DIST_MASK_PRECISE (== 0) requests the exact Euclidean transform.
    return cv2.distanceTransform(mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)

def trimap_transform(trimap, L=320):
    """Encode a two-channel trimap as six Gaussian "click" maps.

    For each channel ``k`` of ``trimap`` (indexed as ``trimap[:, :, k]``,
    k = 0, 1), let ``d`` be the distance from each pixel to the nearest pixel
    where that channel is 1, as computed by ``dt(1 - channel)``. The function
    then emits ``exp(-d**2 / (2 * (s * L)**2))`` for ``s`` in
    (0.02, 0.08, 0.16).

    Returns:
        numpy array of shape ``(6, H, W)``, ordered channel-major:
        ``[k0/s0.02, k0/s0.08, k0/s0.16, k1/s0.02, k1/s0.08, k1/s0.16]``.
    """
    scales = (0.02, 0.08, 0.16)
    maps = []
    for channel in range(2):
        neg_sq_dist = -dt(1 - trimap[:, :, channel]) ** 2
        maps.extend(np.exp(neg_sq_dist / (2 * ((s * L) ** 2))) for s in scales)
    return np.array(maps)

def normalise_image(image):
    """Standardise an ``(N, 3, H, W)`` image batch with fixed per-channel mean/std.

    The constants are the usual ImageNet statistics in RGB order. Images read
    with OpenCV are BGR and must be converted to RGB before calling this.
    """
    dev = image.device
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=dev).view(1, 3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=dev).view(1, 3, 1, 1)
    centred = image - channel_mean
    return centred / channel_std

def pyramid_pooling_module(scale):
    """One pyramid-pooling branch: pool 2048-channel features to ``scale`` x ``scale``.

    The pooled features are then projected to 256 channels, normalised and
    passed through LeakyReLU. The Sequential indices (0..3) form part of the
    pretrained state-dict keys, so the layer order must not change.
    """
    layers = [
        MyAdaptiveAvgPool2d(scale),
        conv1x1(2048, 256, bias=True),
        norm(256),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*layers)

def resize(x, **kwargs):
    """Bilinearly resample ``x`` with ``align_corners=False``.

    Extra keyword arguments (e.g. ``size=`` or ``scale_factor=``) are forwarded
    to ``interpolate``.
    """
    return F.interpolate(x, mode="bilinear", align_corners=False, **kwargs)

class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand.

    Each convolution is followed by ``norm``. ReLU is applied after the first
    two and after the residual addition. If ``downsample`` is given, it maps
    the block input onto the output shape for the skip connection. Otherwise
    the input is added unchanged, so it must already have
    ``planes * expansion`` channels and the same spatial size.
    """

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        padding=1,
        dilation=1,
        expansion=4,
        downsample=None,
    ):
        super().__init__()
        out_planes = planes * expansion
        # Attribute names are state-dict keys of the pretrained checkpoint.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm(planes)
        self.conv2 = conv3x3(planes, planes, stride, padding, dilation)
        self.bn2 = norm(planes)
        self.conv3 = conv1x1(planes, out_planes)
        self.bn3 = norm(out_planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return F.relu(h)

class ResnetDilated(nn.Module):
    """Dilated ResNet encoder (bottleneck block counts 3, 4, 6, 3 as in ResNet-50).

    Only the stem convolution, the max-pool and the first block of ``layer2``
    use stride 2, so the deepest features are at 1/8 of the input resolution.
    ``layer3`` and ``layer4`` keep stride 1 and use dilation (2 and 4) instead
    of further downsampling.

    Attribute names and Sequential indices form the state-dict keys of the
    pretrained checkpoint, which is loaded with ``strict=True``, so they must
    not change.

    Args:
        in_channels: number of channels of the encoder input.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Stem: 7x7 stride-2 convolution -> 1/2 resolution, 64 channels.
        self.conv1 = Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(64)
        # Stride-2 max-pool -> 1/4 resolution.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 3 blocks, output 256 channels at 1/4 resolution.
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, stride=1, padding=1, dilation=1, downsample=nn.Sequential(conv1x1(64, 256), norm(256))),
            Bottleneck(256, 64, stride=1, padding=1, dilation=1),
            Bottleneck(256, 64, stride=1, padding=1, dilation=1),
        )
        # 4 blocks; the first one has stride 2 -> 512 channels at 1/8 resolution.
        self.layer2 = nn.Sequential(
            Bottleneck(256, 128, stride=2, padding=1, dilation=1, downsample=nn.Sequential(conv1x1(256, 512, stride=2), norm(512)),),
            Bottleneck(512, 128, stride=1, padding=1, dilation=1),
            Bottleneck(512, 128, stride=1, padding=1, dilation=1),
            Bottleneck(512, 128, stride=1, padding=1, dilation=1),
        )
        # 6 blocks, stride 1; blocks after the first use dilation 2 -> 1024 channels, still 1/8.
        self.layer3 = nn.Sequential(
            Bottleneck( 512, 256, stride=1, padding=1, dilation=1, downsample=nn.Sequential(conv1x1(512, 1024), norm(1024))),
            Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
            Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
            Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
            Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
            Bottleneck(1024, 256, stride=1, padding=2, dilation=2),
        )
        # 3 blocks, stride 1; first block dilation 2, the rest dilation 4 -> 2048 channels, still 1/8.
        self.layer4 = nn.Sequential(
            Bottleneck(1024, 512, stride=1, padding=2, dilation=2, downsample=nn.Sequential(conv1x1(1024, 2048), norm(2048))),
            Bottleneck(2048, 512, stride=1, padding=4, dilation=4),
            Bottleneck(2048, 512, stride=1, padding=4, dilation=4),
        )

    def forward(self, x):
        """Run the encoder and return every intermediate feature map.

        Returns:
            list of 6 tensors:
            ``[input, stem (64 ch), layer1 (256 ch), layer2 (512 ch),
            layer3 (1024 ch), layer4 (2048 ch)]``. The decoder indexes
            this list from the end.
        """
        conv_out = [x]
        x = F.relu(self.bn1(self.conv1(x)))
        conv_out.append(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)
        return conv_out

class fba_decoder(nn.Module):
    """FBA decoder.

    Stages:
    1. Pyramid pooling on the deepest encoder features.
    2. Three upsample-and-concatenate stages with skip connections.
    3. A small head predicting alpha, foreground and background.
    4. The FBA fusion step, which refines the three predictions jointly.

    Submodule names and Sequential indices form the state-dict keys of the
    pretrained checkpoint, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.ppm = nn.ModuleList(
            [
                pyramid_pooling_module(scale=1),
                pyramid_pooling_module(scale=2),
                pyramid_pooling_module(scale=3),
                pyramid_pooling_module(scale=6),
            ]
        )
        self.conv_up1 = nn.Sequential(
            conv3x3(2048 + len(self.ppm) * 256, 256, bias=True),
            norm(256),
            nn.LeakyReLU(),
            conv3x3(256, 256, bias=True),
            norm(256),
            nn.LeakyReLU(),
        )
        self.conv_up2 = nn.Sequential(
            conv3x3(256 + 256, 256, bias=True), norm(256), nn.LeakyReLU()
        )
        self.conv_up3 = nn.Sequential(
            conv3x3(256 + 64, 64, bias=True), norm(64), nn.LeakyReLU()
        )
        self.conv_up4 = nn.Sequential(
            nn.Conv2d(64 + 3 + 3 + 2, 32, 3, 1, 1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(32, 16, 3, 1, 1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(16, 7, 1, bias=True),
        )

    def forward(self, conv_out, img, two_chan_trimap):
        """Decode encoder features into alpha, foreground and background.

        Args:
            conv_out: list of encoder feature maps, used from the end:
                ``[-1]`` deepest features (2048 channels), ``[-4]`` 256 channels,
                ``[-5]`` 64 channels, and ``[-6]`` the encoder input, whose
                first 3 channels are reused.
            img: 3-channel image at the same spatial size as ``conv_out[-6]``.
                Presumably RGB in [0, 1]; TODO confirm with the caller.
            two_chan_trimap: 2-channel trimap at the same spatial size.

        Returns:
            Tensor with 7 channels: alpha (1), foreground (3), background (3),
            each clamped to [0, 1].
        """
        conv5 = conv_out[-1]
        # Pyramid pooling: pool at several scales (each branch projects to 256
        # channels), upsample back to conv5's size and concatenate.
        ppm_out = [conv5]
        for ppm in self.ppm:
            small_conv5 = ppm(conv5)
            large_conv5 = resize(small_conv5, size=conv5.shape[2:])
            ppm_out.append(large_conv5)
        x = torch.cat(ppm_out, 1)
        x = self.conv_up1(x)
        x = resize(x, scale_factor=2)
        x = torch.cat((x, conv_out[-4]), 1)
        x = self.conv_up2(x)
        x = resize(x, scale_factor=2)
        x = torch.cat((x, conv_out[-5]), 1)
        x = self.conv_up3(x)
        x = resize(x, scale_factor=2)
        x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
        x = self.conv_up4(x)

        # Raw predictions. They are named fg/bg (not F/B) because assigning to
        # F anywhere in this method would make F local for the whole method,
        # shadowing the module-level torch.nn.functional alias.
        alpha = torch.clamp(x[:, 0][:, None], 0, 1)
        fg = torch.sigmoid(x[:, 1:4])
        bg = torch.sigmoid(x[:, 4:7])

        # FBA fusion. bg is refined using the already-refined fg; keep this
        # order, since changing it changes the outputs.
        fg = alpha * img + (1 - alpha**2) * fg - alpha * (1 - alpha) * bg
        bg = (1 - alpha) * img + (2 * alpha - alpha**2) * bg - alpha * (1 - alpha) * fg
        fg = torch.clamp(fg, 0, 1)
        bg = torch.clamp(bg, 0, 1)
        # Re-estimate alpha as the regularised least-squares solution of
        # img ~ alpha * fg + (1 - alpha) * bg, pulled towards the predicted
        # alpha with weight la.
        la = 0.1
        alpha = (alpha * la + torch.sum((img - bg) * (fg - bg), 1, keepdim=True)) / (
            torch.sum((fg - bg) * (fg - bg), 1, keepdim=True) + la
        )
        alpha = torch.clamp(alpha, 0, 1)
        return torch.cat((alpha, fg, bg), 1)

class MattingModule(nn.Module):
    """Full FBA matting network: ``ResnetDilated`` encoder + ``fba_decoder``.

    ``forward(image, two_chan_trimap, trimap_transformed)`` returns the
    decoder's 7-channel output (alpha, foreground, background).
    """

    def __init__(self):
        super().__init__()
        # 11 encoder input channels = 3 (normalised image)
        # + 6 (trimap_transform maps) + 2 (two-channel trimap).
        self.encoder = ResnetDilated(in_channels=11)
        self.decoder = fba_decoder()

    def forward(self, image, two_chan_trimap, trimap_transformed):
        encoder_input = torch.cat(
            [normalise_image(image), trimap_transformed, two_chan_trimap], dim=1
        )
        features = self.encoder(encoder_input)
        return self.decoder(features, image, two_chan_trimap)

def _download_examples():
    """Download the troll example image, trimap and reference alpha if missing.

    Files are stored under ``examples/`` relative to the current working
    directory.
    """
    urls = (
        "https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/images/troll.png",
        "https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/predictions/troll_alpha.png",
        "https://raw.githubusercontent.com/MarcoForte/FBA_Matting/master/examples/trimaps/troll.png",
    )
    for url in urls:
        _, filename = url.split("/master/")
        if os.path.isfile(filename):
            continue
        print("Downloading", url)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # Timeout so a stalled connection does not hang the test forever.
        with urllib.request.urlopen(url, timeout=60) as r:
            data = r.read()
        with open(filename, "wb") as f:
            f.write(data)


def _run_onnx(model, args, path="model.onnx"):
    """Export ``model`` to ONNX at ``path`` and run it with onnxruntime.

    Returns the first model output as a numpy array.
    """
    import onnxruntime

    torch.onnx.export(model, args, path, verbose=True)

    # Since onnxruntime 1.9, builds with several execution providers (e.g. GPU
    # builds) require ``providers`` to be passed explicitly.
    sess = onnxruntime.InferenceSession(path, providers=onnxruntime.get_available_providers())

    input_feed = {inp.name: arg.detach().cpu().numpy()
        for inp, arg in zip(sess.get_inputs(), args)}

    return sess.run(None, input_feed)[0]


def test(use_onnx=True):
    """End-to-end check of the rebuilt FBA-Net on the troll example.

    Runs the model (via ONNX export + onnxruntime by default) and compares the
    predicted alpha with the reference prediction from the FBA repository.
    Returns early, with a message, if ``FBA.pth`` is not in the current
    directory.

    Args:
        use_onnx: if True, export to ONNX and run with onnxruntime; otherwise
            run the PyTorch model directly.
    """
    _download_examples()

    if not os.path.isfile("FBA.pth"):
        print("Download the model file from https://github.com/MarcoForte/FBA_Matting?tab=readme-ov-file#models and save it as FBA.pth in the current directory")
        return

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MattingModule()
    # The checkpoint is used as a plain state dict, so refuse to unpickle
    # arbitrary objects from the downloaded file.
    state_dict = torch.load("FBA.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    # train(False) also triggers Conv2d.train, which bakes the weight
    # standardisation into the stored weights before inference/export.
    model.train(False)

    # OpenCV loads BGR; flip to RGB and scale to [0, 1].
    image_np = cv2.imread("examples/images/troll.png")[:, :, ::-1] / 255.0
    trimap_np = cv2.imread("examples/trimaps/troll.png", cv2.IMREAD_GRAYSCALE) / 255.0
    # Channel 0: definite background (0), channel 1: definite foreground (1).
    trimap_np = np.stack([trimap_np == 0, trimap_np == 1], axis=2).astype(np.float32)
    h, w = trimap_np.shape[:2]

    # Three stride-2 stages in the encoder, and matching x2 upsampling with
    # skip concatenations in the decoder, need sizes divisible by 8.
    h8 = int(np.ceil(h / 8) * 8)
    w8 = int(np.ceil(w / 8) * 8)
    image_scale_np = cv2.resize(image_np, (w8, h8), interpolation=cv2.INTER_LANCZOS4)
    trimap_scale_np = cv2.resize(trimap_np, (w8, h8), interpolation=cv2.INTER_LANCZOS4)

    with torch.no_grad():
        image = torch.from_numpy(image_scale_np).permute(2, 0, 1)[None, :, :, :].float().to(device)
        trimap = torch.from_numpy(trimap_scale_np).permute(2, 0, 1)[None, :, :, :].float().to(device)
        trimap_transformed = torch.from_numpy(trimap_transform(trimap_scale_np))[None, :, :, :].float().to(device)

        if use_onnx:
            output = _run_onnx(model, (image, trimap, trimap_transformed))[0]
        else:
            output = model(image, trimap, trimap_transformed)[0].cpu().numpy()

    # (7, H8, W8) -> (H8, W8, 7)
    output = output.transpose(1, 2, 0)

    # interpolation must be passed by keyword: the third positional argument
    # of cv2.resize is ``dst``, so a positional flag is silently ignored.
    output = cv2.resize(output, (w, h), interpolation=cv2.INTER_LANCZOS4)
    # Lanczos can overshoot; alpha, F and B are all in [0, 1] by construction.
    output = np.clip(output, 0, 1)

    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    # Known trimap regions override the prediction.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]

    alpha_expected = cv2.imread("examples/predictions/troll_alpha.png", cv2.IMREAD_GRAYSCALE) / 255.0

    mse = np.mean(np.square(alpha - alpha_expected))

    print(f"MSE: {mse:.20f}")

    assert mse < 1e-6, "Error too large. I blame the developers of some dependency."

    print("Test passed :)")

# Run the end-to-end self-test (ONNX path by default) only when this file is
# executed as a script, not when it is imported.
if __name__ == "__main__":
    test()