simplew2011 closed this issue 1 year ago.
I am not sure whether you are allowed to deploy the model at all: https://github.com/MarcoForte/FBA_Matting#models
These models have been trained on the Adobe Image Matting Dataset. They are covered by the Adobe Deep Image Matting Dataset License Agreement, so they can only be used and distributed for noncommercial purposes.
However, if you are allowed to deploy the model, you do not need different pretrained weights. You can simply replace nn.GroupNorm with your own GroupNorm implementation and load the original weights.
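import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MyGroupNorm(nn.Module):
    # Replacement for nn.GroupNorm (affine=True) built from basic tensor ops.
    # Parameter names match nn.GroupNorm, so the original state_dict loads as-is.
    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        # Same initialization as nn.GroupNorm; overwritten when the weights are loaded
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        n, c, h, w = x.shape
        g = self.num_groups
        # Normalize each group of c // g channels over its channels and spatial positions
        x = x.reshape(n, g, -1)
        var = x.var(dim=-1, keepdim=True, correction=0)  # biased variance, as in nn.GroupNorm
        x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(var + self.eps)
        # Per-channel affine transform
        x = x.reshape(n, c, -1) * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return x.reshape(n, c, h, w)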
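The swap itself is not shown in the thread. Below is a minimal sketch, assuming every nn.GroupNorm in the model is affine (the default). `replace_group_norm` is a hypothetical helper that walks the module tree, replaces each layer under the same attribute name and copies its weight and bias. Because the module paths and parameter names stay the same, the original checkpoint should also load after the swap. The small `nn.Sequential` model is a stand-in for the real network, used only to check that the outputs match.

```python
import torch
import torch.nn as nn


class MyGroupNorm(nn.Module):
    # Mathematically the same as the MyGroupNorm above, repeated so this block runs on its own
    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        n, c, h, w = x.shape
        x = x.reshape(n, self.num_groups, -1)
        mean = x.mean(dim=-1, keepdim=True)
        var = (x - mean).pow(2).mean(dim=-1, keepdim=True)  # biased variance
        x = (x - mean) / torch.sqrt(var + self.eps)
        x = x.reshape(n, c, -1) * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
        return x.reshape(n, c, h, w)


def replace_group_norm(module: nn.Module) -> None:
    """Recursively replace every affine nn.GroupNorm with MyGroupNorm, copying its parameters."""
    for name, child in module.named_children():
        if isinstance(child, nn.GroupNorm):
            new = MyGroupNorm(child.num_groups, child.num_channels, child.eps)
            new.load_state_dict(child.state_dict())  # keys 'weight' and 'bias' match
            setattr(module, name, new)
        else:
            replace_group_norm(child)


# Toy model standing in for the real network
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.GroupNorm(4, 8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 16, 3), nn.GroupNorm(8, 16)),
)
for m in model.modules():
    if isinstance(m, nn.GroupNorm):
        # Non-trivial affine parameters, so the copy is actually exercised
        nn.init.normal_(m.weight)
        nn.init.normal_(m.bias)
model.eval()

x = torch.randn(2, 3, 10, 10)
with torch.no_grad():
    reference = model(x)
    replace_group_norm(model)
    swapped = model(x)
print(torch.allclose(swapped, reference, atol=1e-4))
```

Swapping modules in place (rather than editing the model definition) keeps the checkpoint loading code unchanged.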
When testing this with ONNX, I found that I also had to replace nn.AdaptiveAvgPool2d.
class MyAdaptiveAvgPool2d(nn.Module):
    # Replacement for nn.AdaptiveAvgPool2d with a square output (output_size is an int)
    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size

    def forward(self, batch):
        size = self.output_size
        n, c, h, w = batch.shape
        output = torch.zeros((n, c, size, size), dtype=batch.dtype, device=batch.device)
        for y in range(size):
            for x in range(size):
                # Each output cell averages the input window [floor(i*in/out), ceil((i+1)*in/out))
                x0 = math.floor(x * w / size)
                y0 = math.floor(y * h / size)
                x1 = math.ceil((x + 1) * w / size)
                y1 = math.ceil((y + 1) * h / size)
                output[:, :, y, x] = batch[:, :, y0:y1, x0:x1].mean(dim=(2, 3))
        return output
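To see what the floor/ceil bounds in the loop produce, here is a small pure-Python sketch (`adaptive_bins` is a hypothetical helper) that lists the half-open input range averaged by each output cell along one axis. When the input size is not a multiple of the output size, neighbouring windows overlap.

```python
import math
from typing import List, Tuple


def adaptive_bins(in_size: int, out_size: int) -> List[Tuple[int, int]]:
    """Half-open [start, end) input ranges averaged by each output cell,
    using the same floor/ceil rule as MyAdaptiveAvgPool2d."""
    return [
        (math.floor(i * in_size / out_size), math.ceil((i + 1) * in_size / out_size))
        for i in range(out_size)
    ]


print(adaptive_bins(5, 3))  # [(0, 2), (1, 4), (3, 5)]: windows overlap
print(adaptive_bins(6, 3))  # [(0, 2), (2, 4), (4, 6)]: evenly divisible, no overlap
```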
This implementation of adaptive_avg_pool2d only works when the input shape stays the same across all invocations, which is the case for this neural network: the window bounds are computed in Python from h and w, so tracing-based ONNX export bakes them into the graph as constants. Keep this in mind if you reuse it in other networks.
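A possible alternative, not from this thread: adaptive average pooling is separable, so the whole output can be computed as P_h · X · P_wᵀ, where each row of P_h and P_w averages one window. The sketch below (`pooling_matrix` and `MatmulAdaptiveAvgPool2d` are hypothetical names) builds these matrices with the same floor/ceil windows and replaces the per-cell in-place writes with two matrix multiplications, which should export as constant matrices plus MatMul nodes. Like the loop version, it builds the matrices from h and w, so an exported graph is only valid for the input size used at export time.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def pooling_matrix(in_size: int, out_size: int) -> torch.Tensor:
    """(out_size, in_size) matrix whose row i averages input positions
    [floor(i * in / out), ceil((i + 1) * in / out))."""
    m = torch.zeros(out_size, in_size)
    for i in range(out_size):
        start = math.floor(i * in_size / out_size)
        end = math.ceil((i + 1) * in_size / out_size)
        m[i, start:end] = 1.0 / (end - start)
    return m


class MatmulAdaptiveAvgPool2d(nn.Module):
    """Adaptive average pooling to a square output, computed with two matmuls."""

    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        ph = pooling_matrix(h, self.output_size).to(x)  # (out, h), matching x's dtype/device
        pw = pooling_matrix(w, self.output_size).to(x)  # (out, w)
        # (out, h) @ (n, c, h, w) -> (n, c, out, w); then @ (w, out) -> (n, c, out, out)
        return ph @ x @ pw.t()


# Compare against PyTorch on a size that is not divisible by most output sizes
torch.manual_seed(0)
x = torch.randn(2, 4, 17, 23)
for size in (1, 2, 3, 6):
    out = MatmulAdaptiveAvgPool2d(size)(x)
    print(size, torch.allclose(out, F.adaptive_avg_pool2d(x, size), atol=1e-5))
```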
Thanks.
Can you provide pretrained weights based on the resnet_bn backbone? The current ResNet GN+WS backbone contains the nn.GroupNorm operator, which makes it difficult to deploy.