MarcoForte / FBA_Matting

Official repository for the paper F, B, Alpha Matting
MIT License
464 stars, 95 forks

resnet_bn backbone #53

simplew2011 closed this issue 1 year ago

simplew2011 commented 1 year ago

Can you provide pretrained weights based on the resnet_bn backbone? The current resnet_GN_WS backbone contains the nn.GroupNorm operator, which makes it difficult to deploy.

99991 commented 1 year ago

I am not sure whether you are allowed to deploy the model. From https://github.com/MarcoForte/FBA_Matting#models:

These models have been trained on Adobe Image Matting Dataset. They are covered by the Adobe Deep Image Matting Dataset License Agreement so they can only be used and distributed for noncommercial purposes.

However, if you are allowed to deploy the model, then you do not need different pretrained weights. You can simply replace nn.GroupNorm with your own GroupNorm implementation and load the original weights.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

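class MyGroupNorm(nn.Module):
    """Replacement for nn.GroupNorm on (N, C, H, W) inputs, built from elementary ops.

    The parameter names (weight, bias) match nn.GroupNorm, so the original
    pretrained state_dict loads without renaming.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        # Same initial values as nn.GroupNorm: scale 1, shift 0.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        n, c, h, w = x.shape
        g = self.num_groups
        # Normalize each group over its c // g channels and all pixels.
        x = x.reshape(n, g, -1)
        # Biased variance, as nn.GroupNorm uses (correction= needs PyTorch 2.0+).
        var = x.var(dim=-1, keepdim=True, correction=0)
        x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(var + self.eps)
        # Per-channel affine transform.
        x = x.reshape(n, c, -1) * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return x.reshape(n, c, h, w)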
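Swapping every GroupNorm by hand is tedious in a deep ResNet, so here is a minimal sketch of a generic helper that walks the module tree and replaces each instance of a given layer type. The name replace_modules and its make_replacement callback are hypothetical (not part of FBA_Matting or PyTorch), and the sketch assumes the replacement exposes the same state_dict keys as the module it replaces, which holds for MyGroupNorm versus an affine nn.GroupNorm.

```python
import torch.nn as nn


def replace_modules(module: nn.Module, target_type: type, make_replacement) -> nn.Module:
    """Recursively replace every child module that is an instance of target_type.

    Hypothetical helper: make_replacement(old) must return a module whose
    state_dict keys and shapes match the old one, so the already-loaded
    parameters can be copied across unchanged.
    """
    # Iterate over a snapshot of the children so reassigning them is safe.
    for name, child in list(module.named_children()):
        if isinstance(child, target_type):
            new = make_replacement(child)
            # Copy the trained values (strict: key names and shapes must match).
            new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            replace_modules(child, target_type, make_replacement)
    return module
```

For example, after loading the original weights: replace_modules(model, nn.GroupNorm, lambda gn: MyGroupNorm(gn.num_groups, gn.num_channels, gn.eps)). The replacements are created on the CPU, so move the model to its target device afterwards. The same helper works for any other layer type that needs swapping.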

When testing this with ONNX, I found that I also had to replace nn.AdaptiveAvgPool2d.

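class MyAdaptiveAvgPool2d(nn.Module):
    """Adaptive average pooling unrolled into slices and means.

    Assumes output_size is a single int, i.e. a square output.
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def forward(self, batch):
        size = self.output_size
        n, c, h, w = batch.shape
        output = torch.zeros((n, c, size, size), dtype=batch.dtype, device=batch.device)
        for y in range(size):
            for x in range(size):
                # Same window rule as PyTorch: [floor(i * in / out), ceil((i + 1) * in / out)).
                x0 = math.floor(x * w / size)
                y0 = math.floor(y * h / size)
                x1 = math.ceil((x + 1) * w / size)
                y1 = math.ceil((y + 1) * h / size)
                # Average over the spatial window for every sample and channel.
                output[:, :, y, x] = batch[:, :, y0:y1, x0:x1].mean(dim=(2, 3))
        return output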

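This implementation of adaptive_avg_pool2d will only work when the shape of the batch stays the same for all invocations, which is the case for this neural network. The window boundaries are computed in Python from batch.shape, so an exported ONNX graph keeps the boundaries for the input size used at export time. Just be aware of this if you try to use it for other networks.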
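To see what the floor/ceil lines in the loop compute, the sketch below lists the [start, end) window for each output cell along one axis. adaptive_bins is a hypothetical helper written for illustration; it applies the same rule with integer arithmetic instead of math.floor and math.ceil. When the input size is not a multiple of the output size, neighbouring windows overlap, and since the windows depend only on the input size, they are fixed once that size is fixed.

```python
def adaptive_bins(in_size: int, out_size: int) -> list[tuple[int, int]]:
    """Return the [start, end) window of each output cell in adaptive average pooling.

    start = floor(i * in_size / out_size), end = ceil((i + 1) * in_size / out_size),
    computed with integer arithmetic to avoid any float rounding.
    """
    bins = []
    for i in range(out_size):
        start = (i * in_size) // out_size
        end = -((-(i + 1) * in_size) // out_size)  # ceiling division
        bins.append((start, end))
    return bins


print(adaptive_bins(5, 3))  # [(0, 2), (1, 4), (3, 5)]
```

Here the 5 input positions are pooled into 3 outputs, and position 1 and position 3 each fall into two windows.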

simplew2011 commented 1 year ago

Thanks.