Closed. luoww1992 closed this issue 3 years ago.
Hi, thanks for your attention.
For your questions:
Q1: it needs 3 folders -- original, trimap, matte? Yes. And the trimap can be generated from the matte.
Q2: so the size of the image must be 512? No, you can use any size to train the model.
Q3: and the image needs other operations, like changing the color and so on? You can use the most common data augmentations to process the training data, e.g., flipping and normalization.
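For illustration, here is a small sketch of joint flipping plus scaling to [0, 1] (my own example, not code from this repository); the same flip has to be applied to the image, the trimap, and the matte so they stay aligned:
import random
import numpy as np
import torch

def flip_and_scale(img, trimap, matte):
    # Shapes assumed for illustration: img HxWx3 uint8, trimap/matte HxW uint8.
    if random.random() < 0.5:  # horizontal flip, applied to all three jointly
        img, trimap, matte = img[:, ::-1], trimap[:, ::-1], matte[:, ::-1]
    img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() / 255.0
    matte = torch.from_numpy(np.ascontiguousarray(matte)).unsqueeze(0).float() / 255.0
    # The trimap must keep the exact values 0 / 0.5 / 1; see the later comments in this thread.
    return img, trimap, matte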
@luoww1992 would you be able to share your successful training approach? Did you write a dataloader and a trimap creator? Any tips or code would be greatly appreciated! Thank you!
I am making the trimap dataset; it will take some time.
.......
I am organizing the code.
@ZHKKKe, I am training on a video-matting dataset, but the loss is very large, and the detail_loss is 0 with no change. My train.py:
import math
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from functools import reduce
import cv2
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import grey_dilation, grey_erosion
from scipy.ndimage import morphology
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
from math import *
__all__ = [
'supervised_training_iter',
'soc_adaptation_iter',
]
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expansion, dilation=1):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = round(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
super(MobileNetV2, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
input_channel = 32
last_channel = 1280
interverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[expansion, 24, 2, 2],
[expansion, 32, 3, 2],
[expansion, 64, 4, 2],
[expansion, 96, 3, 1],
[expansion, 160, 3, 2],
[expansion, 320, 1, 1],
]
# building first layer
input_channel = _make_divisible(input_channel * alpha, 8)
self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
self.features = [conv_bn(self.in_channels, input_channel, 2)]
# building inverted residual blocks
for t, c, n, s in interverted_residual_setting:
output_channel = _make_divisible(int(c * alpha), 8)
for i in range(n):
if i == 0:
self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
else:
self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
input_channel = output_channel
# building last several layers
self.features.append(conv_1x1_bn(input_channel, self.last_channel))
# make it nn.Sequential
self.features = nn.Sequential(*self.features)
# building classifier
if self.num_classes is not None:
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
# Initialize weights
self._init_weights()
def forward(self, x, feature_names=None):
# Stage1
x = reduce(lambda x, n: self.features[n](x), list(range(0, 2)), x)
# Stage2
x = reduce(lambda x, n: self.features[n](x), list(range(2, 4)), x)
# Stage3
x = reduce(lambda x, n: self.features[n](x), list(range(4, 7)), x)
# Stage4
x = reduce(lambda x, n: self.features[n](x), list(range(7, 14)), x)
# Stage5
x = reduce(lambda x, n: self.features[n](x), list(range(14, 19)), x)
# Classification
if self.num_classes is not None:
x = x.mean(dim=(2, 3))
x = self.classifier(x)
# Output
return x
def _load_pretrained_model(self, pretrained_file):
pretrain_dict = torch.load(pretrained_file, map_location='cpu')
model_dict = {}
state_dict = self.state_dict()
print("[MobileNetV2] Loading pretrained model...")
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
else:
print(k, "is ignored")
state_dict.update(model_dict)
self.load_state_dict(state_dict)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
class IBNorm(nn.Module):
""" Combine Instance Norm and Batch Norm into One Layer
"""
def __init__(self, in_channels):
super(IBNorm, self).__init__()
in_channels = in_channels
self.bnorm_channels = int(in_channels / 2)
self.inorm_channels = in_channels - self.bnorm_channels
self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
def forward(self, x):
bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
return torch.cat((bn_x, in_x), 1)
class Conv2dIBNormRelu(nn.Module):
""" Convolution + IBNorm + ReLu
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, groups=1, bias=True,
with_ibn=True, with_relu=True):
super(Conv2dIBNormRelu, self).__init__()
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
]
if with_ibn:
layers.append(IBNorm(out_channels))
if with_relu:
layers.append(nn.ReLU(inplace=True))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class BaseBackbone(nn.Module):
""" Superclass of Replaceable Backbone Model for Semantic Estimation
"""
def __init__(self, in_channels):
super(BaseBackbone, self).__init__()
self.in_channels = in_channels
self.model = None
self.enc_channels = []
def forward(self, x):
raise NotImplementedError
def load_pretrained_ckpt(self):
raise NotImplementedError
class MobileNetV2Backbone(BaseBackbone):
""" MobileNetV2 Backbone
"""
def __init__(self, in_channels):
super(MobileNetV2Backbone, self).__init__(in_channels)
self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
self.enc_channels = [16, 24, 32, 96, 1280]
def forward(self, x):
x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
enc2x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
enc4x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
enc8x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
enc16x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
enc32x = x
return [enc2x, enc4x, enc8x, enc16x, enc32x]
def load_pretrained_ckpt(self):
# the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
ckpt_path = './mobilenetv2_human_seg.ckpt'
if not os.path.exists(ckpt_path):
print('cannot find the pretrained mobilenetv2 backbone')
exit()
ckpt = torch.load(ckpt_path)
self.model.load_state_dict(ckpt)
class SEBlock(nn.Module):
""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
"""
def __init__(self, in_channels, out_channels, reduction=1):
super(SEBlock, self).__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, int(in_channels // reduction), bias=False),
nn.ReLU(inplace=True),
nn.Linear(int(in_channels // reduction), out_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
w = self.pool(x).view(b, c)
w = self.fc(w).view(b, c, 1, 1)
return x * w.expand_as(x)
class LRBranch(nn.Module):
""" Low Resolution Branch of MODNet
"""
def __init__(self, backbone):
super(LRBranch, self).__init__()
enc_channels = backbone.enc_channels
self.backbone = backbone
self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False,
with_relu=False)
def forward(self, img, inference):
enc_features = self.backbone.forward(img)
enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
enc32x = self.se_block(enc32x)
lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
lr16x = self.conv_lr16x(lr16x)
lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
lr8x = self.conv_lr8x(lr8x)
pred_semantic = None
if not inference:
lr = self.conv_lr(lr8x)
pred_semantic = torch.sigmoid(lr)
return pred_semantic, lr8x, [enc2x, enc4x]
class HRBranch(nn.Module):
""" High Resolution Branch of MODNet
"""
def __init__(self, hr_channels, enc_channels):
super(HRBranch, self).__init__()
self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
self.conv_hr4x = nn.Sequential(
Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
)
self.conv_hr2x = nn.Sequential(
Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
)
self.conv_hr = nn.Sequential(
Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
)
def forward(self, img, enc2x, enc4x, lr8x, inference):
img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
enc2x = self.tohr_enc2x(enc2x)
hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
enc4x = self.tohr_enc4x(enc4x)
hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
pred_detail = None
if not inference:
hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
hr = self.conv_hr(torch.cat((hr, img), dim=1))
pred_detail = torch.sigmoid(hr)
return pred_detail, hr2x
class FusionBranch(nn.Module):
""" Fusion Branch of MODNet
"""
def __init__(self, hr_channels, enc_channels):
super(FusionBranch, self).__init__()
self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
self.conv_f = nn.Sequential(
Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
)
def forward(self, img, lr8x, hr2x):
lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
lr4x = self.conv_lr4x(lr4x)
lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
f = self.conv_f(torch.cat((f, img), dim=1))
pred_matte = torch.sigmoid(f)
return pred_matte
class MODNet(nn.Module):
""" Architecture of MODNet
"""
def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
super(MODNet, self).__init__()
self.in_channels = in_channels
self.hr_channels = hr_channels
self.backbone_arch = backbone_arch
self.backbone_pretrained = backbone_pretrained
self.backbone = MobileNetV2Backbone(self.in_channels)
self.lr_branch = LRBranch(self.backbone)
self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
for m in self.modules():
if isinstance(m, nn.Conv2d):
self._init_conv(m)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
self._init_norm(m)
if self.backbone_pretrained:
self.backbone.load_pretrained_ckpt()
def forward(self, img, inference):
pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
pred_matte = self.f_branch(img, lr8x, hr2x)
return pred_semantic, pred_detail, pred_matte
def freeze_norm(self):
norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
for m in self.modules():
for n in norm_types:
if isinstance(m, n):
m.eval()
continue
def _init_conv(self, conv):
nn.init.kaiming_uniform_(
conv.weight, a=0, mode='fan_in', nonlinearity='relu')
if conv.bias is not None:
nn.init.constant_(conv.bias, 0)
def _init_norm(self, norm):
if norm.weight is not None:
nn.init.constant_(norm.weight, 1)
nn.init.constant_(norm.bias, 0)
class GaussianBlurLayer(nn.Module):
""" Add Gaussian Blur to a 4D tensors
This layer takes a 4D tensor of {N, C, H, W} as input.
The Gaussian blur will be performed in given channel number (C) splitly.
"""
def __init__(self, channels, kernel_size):
"""
Arguments:
channels (int): Channel for input tensor
kernel_size (int): Size of the kernel used in blurring
"""
super(GaussianBlurLayer, self).__init__()
self.channels = channels
self.kernel_size = kernel_size
assert self.kernel_size % 2 != 0
self.op = nn.Sequential(
nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
nn.Conv2d(channels, channels, self.kernel_size,
stride=1, padding=0, bias=None, groups=channels)
)
self._init_kernel()
def forward(self, x):
"""
Arguments:
x (torch.Tensor): input 4D tensor
Returns:
torch.Tensor: Blurred version of the input
"""
if not len(list(x.shape)) == 4:
print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
exit()
elif not x.shape[1] == self.channels:
print('In \'GaussianBlurLayer\', the required channel ({0}) is'
'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
exit()
return self.op(x)
def _init_kernel(self):
sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
n = np.zeros((self.kernel_size, self.kernel_size))
i = math.floor(self.kernel_size / 2)
n[i, i] = 1
kernel = scipy.ndimage.gaussian_filter(n, sigma)
for name, param in self.named_parameters():
param.data.copy_(torch.from_numpy(kernel))
class ImagesDataset(Dataset):
def __init__(self, root, transforms=None, w=960, h=544):
self.root = root
self.transforms = transforms
self.w = w
self.h = h
self.imgs = sorted(os.listdir(os.path.join(self.root, 'image')))
self.alphas = sorted(os.listdir(os.path.join(self.root, 'alpha')))
assert len(self.imgs) == len(self.alphas), 'the number of dataset is different, please check it.'
def get_trimap(self, alpha):
# alpha \in [0, 1] should be taken into account
# be careful when dealing with regions of alpha=0 and alpha=1
fg = np.array(np.equal(alpha, 255).astype(np.float32))
unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) # unknown = alpha > 0
unknown = unknown - fg
# image dilation implemented by Euclidean distance transform
unknown = morphology.distance_transform_edt(unknown==0) <= np.random.randint(1, 20)
trimap = fg * 255
trimap[unknown] = 128
return trimap.astype(np.uint8)
def __len__(self):
return len(self.imgs)
def __getitem__(self, idx):
img = cv2.imread(os.path.join(self.root, 'image', self.imgs[idx]))
alpha = cv2.imread(os.path.join(self.root, 'alpha', self.alphas[idx]))
trimap = self.get_trimap(alpha)
# cv2.imshow('trimap', trimap)
# cv2.waitKey(0)
h, w, c = img.shape
if not (w == self.w and h == self.h):
img = cv2.resize(img, (self.w, self.h))
trimap = cv2.resize(trimap, (self.w, self.h))
alpha = cv2.resize(alpha, (self.w, self.h))
if self.transforms:
img = self.transforms(img)
trimap = self.transforms(trimap)
alpha = self.transforms(alpha)
return img, trimap, alpha
blurer = GaussianBlurLayer(3, 3).cuda()
def supervised_training_iter(
modnet, optimizer, image, trimap, gt_matte,
semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
""" Supervised training iteration of MODNet
This function trains MODNet for one iteration in a labeled dataset.
Arguments:
modnet (torch.nn.Module): instance of MODNet
optimizer (torch.optim.Optimizer): optimizer for supervised training
image (torch.autograd.Variable): input RGB image
trimap (torch.autograd.Variable): trimap used to calculate the losses
NOTE: foreground=1, background=0, unknown=0.5
gt_matte (torch.autograd.Variable): ground truth alpha matte
semantic_scale (float): scale of the semantic loss
NOTE: please adjust according to your dataset
detail_scale (float): scale of the detail loss
NOTE: please adjust according to your dataset
matte_scale (float): scale of the matte loss
NOTE: please adjust according to your dataset
Returns:
semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
Example:
import torch
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter
bs = 16 # batch size
lr = 0.01 # learn rate
epochs = 40 # total epochs
modnet = torch.nn.DataParallel(MODNet()).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function
for epoch in range(0, epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
semantic_loss, detail_loss, matte_loss = \
supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
lr_scheduler.step()
"""
global blurer
# set the model to train mode and clear the optimizer
modnet.train()
optimizer.zero_grad()
# forward the model
pred_semantic, pred_detail, pred_matte = modnet(image, False)
# calculate the boundary mask from the trimap
boundaries = (trimap < 0.5) + (trimap > 0.5)
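# 'boundaries' is True for known foreground/background pixels (trimap != 0.5);
# torch.where() below uses it so that the detail loss is effectively computed
# only over the unknown transition region of the trimap.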
# calculate the semantic loss
gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
gt_semantic = blurer(gt_semantic)
semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
semantic_loss = semantic_scale * semantic_loss
# calculate the detail loss
pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
gt_detail = torch.where(boundaries, trimap, gt_matte)
detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
detail_loss = detail_scale * detail_loss
# calculate the matte loss
pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
+ 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
matte_loss = matte_scale * matte_loss
# calculate the final loss, backward the loss, and update the model
loss = semantic_loss + detail_loss + matte_loss
loss.backward()
optimizer.step()
# for test
return semantic_loss, detail_loss, matte_loss
def soc_adaptation_iter(
modnet, backup_modnet, optimizer, image,
soc_semantic_scale=100.0, soc_detail_scale=1.0):
""" Self-Supervised sub-objective consistency (SOC) adaptation iteration of MODNet
This function fine-tunes MODNet for one iteration in an unlabeled dataset.
Note that SOC can only fine-tune a converged MODNet, i.e., MODNet that has been
trained in a labeled dataset.
Arguments:
modnet (torch.nn.Module): instance of MODNet
backup_modnet (torch.nn.Module): backup of the trained MODNet
optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC
image (torch.autograd.Variable): input RGB image
soc_semantic_scale (float): scale of the SOC semantic loss
NOTE: please adjust according to your dataset
soc_detail_scale (float): scale of the SOC detail loss
NOTE: please adjust according to your dataset
Returns:
soc_semantic_loss (torch.Tensor): loss of the semantic SOC
soc_detail_loss (torch.Tensor): loss of the detail SOC
Example:
import copy
import torch
from src.models.modnet import MODNet
from src.trainer import soc_adaptation_iter
bs = 1 # batch size
lr = 0.00001 # learn rate
epochs = 10 # total epochs
modnet = torch.nn.DataParallel(MODNet()).cuda()
modnet = LOAD_TRAINED_CKPT() # NOTE: please finish this function
optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function
for epoch in range(0, epochs):
backup_modnet = copy.deepcopy(modnet)
for idx, (image) in enumerate(dataloader):
soc_semantic_loss, soc_detail_loss = \
soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
"""
global blurer
# set the backup model to eval mode
backup_modnet.eval()
# set the main model to train mode and freeze its norm layers
modnet.train()
modnet.module.freeze_norm()
# clear the optimizer
optimizer.zero_grad()
# forward the main model
pred_semantic, pred_detail, pred_matte = modnet(image, False)
# forward the backup model
with torch.no_grad():
_, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)
# calculate the boundary mask from `pred_matte` and `pred_semantic`
pred_matte_fg = (pred_matte.detach() > 0.1).float()
pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
pred_fg = pred_matte_fg * pred_semantic_fg
n, c, h, w = pred_matte.shape
np_pred_fg = pred_fg.data.cpu().numpy()
np_boundaries = np.zeros([n, c, h, w])
for sdx in range(0, n):
sample_np_boundaries = np_boundaries[sdx, 0, ...]
sample_np_pred_fg = np_pred_fg[sdx, 0, ...]
side = int((h + w) / 2 * 0.05)
dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
eroded = grey_erosion(sample_np_pred_fg, size=(side, side))
sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
np_boundaries[sdx, 0, ...] = sample_np_boundaries
boundaries = torch.tensor(np_boundaries).float().cuda()
# sub-objectives consistency between `pred_semantic` and `pred_matte`
# generate pseudo ground truth for `pred_semantic`
downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1 / 16, mode='bilinear'))
pseudo_gt_semantic = downsampled_pred_matte.detach()
pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()
# generate pseudo ground truth for `pred_matte`
pseudo_gt_matte = pred_semantic.detach()
pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()
# calculate the SOC semantic loss
soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte,
pseudo_gt_matte)
soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)
# NOTE: using the formulas in our paper to calculate the following losses has similar results
# sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail)
backup_detail_loss = torch.sum(backup_detail_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
backup_detail_loss = torch.mean(backup_detail_loss)
# sub-objectives consistency between pred_matte` and `pred_backup_matte` (on boundaries only)
backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte)
backup_matte_loss = torch.sum(backup_matte_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
backup_matte_loss = torch.mean(backup_matte_loss)
soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)
# calculate the final loss, backward the loss, and update the model
loss = soc_semantic_loss + soc_detail_loss
loss.backward()
optimizer.step()
return soc_semantic_loss, soc_detail_loss
# ----------------------------------------------------------------------------------
def main(root, pretrained_ckpt=None):  # pass a checkpoint path to resume; None trains from scratch
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet)
GPU = True if torch.cuda.device_count() > 0 else False
if GPU:
print('Use GPU...')
modnet = modnet.cuda()
if pretrained_ckpt is not None: modnet.load_state_dict(torch.load(pretrained_ckpt))
else:
print('Use CPU...')
if pretrained_ckpt is not None: modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu')))
modnet.eval()
bs = 1 # batch size
lr = 0.01 # learn rate
epochs = 40 # total epochs
num_workers = 8
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
# dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function
dataset = ImagesDataset(root)
dataloader = DataLoader(dataset, batch_size=bs, num_workers=num_workers, pin_memory=True)
for epoch in range(epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
image = image.permute(0, 3, 1, 2).float().cuda()  # NHWC -> NCHW
trimap = trimap.permute(0, 3, 1, 2).float().cuda()
gt_matte = gt_matte.permute(0, 3, 1, 2).float().cuda()
semantic_loss, detail_loss, matte_loss = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
print(f"epoch: {epoch+1}/{epochs} semantic_loss: {semantic_loss}, detail_loss: {detail_loss}, matte_loss: {matte_loss}")
lr_scheduler.step()
if __name__ == '__main__':
path = 'MODNet/dataset'
main(path)
My GPU: 2080Ti, 11 GB.
Where did I go wrong?
@ZHKKKe ,
is the size of my video-matting training images too large? Or am I missing some steps when making the dataset?
@luoww1992 Hi, thanks for your attention.
For your questions:
Q1: the loss is large
You need to normalize the ground truth to [0, 1]. Please add transforms to ImagesDataset (a sketch follows below).
Q2: the detail_loss is 0
The pixel values in the loaded trimap should be 0=background, 0.5=unknown, or 1=foreground. Please pre-process it in your dataset.
You can refer to the latest comments for more information:
image (torch.autograd.Variable): input RGB image; its pixel values should be normalized
trimap (torch.autograd.Variable): trimap used to calculate the losses; its pixel values can be 0, 0.5, or 1 (foreground=1, background=0, unknown=0.5)
gt_matte (torch.autograd.Variable): ground truth alpha matte; its pixel values are between [0, 1]
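A minimal sketch of that pre-processing (an assumption about one possible fix, not the repository's official loader), assuming the image is an HxWx3 uint8 BGR array from cv2, the alpha is HxW uint8, and the trimap holds {0, 128, 255} as produced by get_trimap above:
import torch

def to_training_tensors(img, trimap, alpha):
    img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0    # 3xHxW, values in [0, 1]
    alpha = torch.from_numpy(alpha).unsqueeze(0).float() / 255.0    # 1xHxW, values in [0, 1]
    trimap = torch.from_numpy(trimap).unsqueeze(0).float()
    trimap[trimap == 128] = 0.5                                     # unknown region
    trimap[trimap == 255] = 1.0                                     # foreground (background stays 0)
    return img, trimap, alpha
If __getitem__ returns tensors in this layout, the default DataLoader collate already yields NCHW float batches, and no extra transpose is needed in the training loop.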
@ZHKKKe I am training it, and I get some warnings:
UserWarning: Using a target size (torch.Size([4, 3, 36, 64])) that is different to the input size (torch.Size([4, 1, 36, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
UserWarning: Using a target size (torch.Size([4, 3, 576, 1024])) that is different to the input size (torch.Size([4, 1, 576, 1024])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
My input images are 1024*576. Will this cause errors in training?
Please change the channel of gt_matte to 1.
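For example (a sketch assuming OpenCV is used for loading, as in the posted dataset; alpha_path is a placeholder):
import cv2
import torch

alpha = cv2.imread(alpha_path, cv2.IMREAD_GRAYSCALE)             # HxW, a single channel
gt_matte = torch.from_numpy(alpha).unsqueeze(0).float() / 255.0  # 1xHxW, values in [0, 1]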
@luoww1992 Did you manage to train it successfully?
@luoww1992 were you able to train it?
I have made a training dataset with 80k images and I am training on it. I have been running it for 4 days, but it is too slow.
@luoww1992 that's great ... can you share the dataset?
Sorry, I cannot share it before I get authorization; it contains some personal images. You can try to make one yourself from alpha images.
Could you share the complete code that can be trained correctly? I would also appreciate it if you could send me a few training samples (10-20 is enough). My email is wyking9@163.com, thanks a lot!
Hey @luoww1992 just share the complete trainable code.
@luoww1992 for me too shaikhaftab139@gmail.com, thanks
Please share with me the complete code that can be trained correctly, many thanks! Mail: andy910389@gmail.com
I have added it in trainDefault.txt.
Thank you very much! I will read the code carefully, hope it will work!
Now I am training two types of model: 1) the default training setup, using the same size and the model the author provided; 2) my own model, where I change the size and train a new model from scratch. It is very slow to train. I am waiting to finish training and then start the SOC step.
@luoww1992
FYI. I trained the model on a single GPU with batch-size=8 and input-size=512. The total training time is about 2~3 days on a dataset that contains 100k samples.
I am using batch-size=4, an 11 GB GPU, and 80k images, so it will take more time. When you finished training, how far did the matte_loss go down?
@luoww1992 I got an average training matte_loss=0.0364 in the last training epoch. However, the loss value will differ depending on the dataset.
@ZHKKKe I will run the SOC step. I notice we only need images to build the dataset, without alpha images. Don't we need a metric to show us the model is better when there are no alpha images to compare against? How many images did you use to run SOC?
@luoww1992 Since the dataset for SOC is unlabeled, we use visual comparison or a user study to see which model is better. You can use many video clips (we use 40k frames for our WebCam model) to adapt the model to a specific data domain, or you can use only one video clip to adapt the model to a specific video.
@ZHKKKe When running SOC, I could make two datasets for it: 1) choose 40k images from the training dataset; 2) make a new dataset for it. Is there any difference between 1 and 2?
@luoww1992 SOC is a technique for self-supervised domain adaptation. Its goal is to generalize the trained MODNet to a new domain without labeled data. For your options: 1) choose 40k images from the training dataset: SOC will not contribute to the performance; please train MODNet with the labels directly. 2) make a new dataset for it: if the data in the new unlabeled dataset has the same domain as the labeled training dataset, SOC is unnecessary.
In practice, we usually train MODNet on a dataset with background replacement. Therefore, the model may perform badly on natural images without background replacement. SOC should improve the trained model if you fine-tune it on unlabeled natural images without background replacement.
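For reference, a minimal SOC loop following the soc_adaptation_iter docstring shown earlier in this thread; the checkpoint path and CREATE_YOUR_DATALOADER are placeholders you must provide yourself:
import copy
import torch

modnet = torch.nn.DataParallel(MODNet()).cuda()
modnet.load_state_dict(torch.load('your_trained_modnet.ckpt'))   # a converged supervised checkpoint (placeholder path)
optimizer = torch.optim.Adam(modnet.parameters(), lr=0.00001, betas=(0.9, 0.99))
dataloader = CREATE_YOUR_DATALOADER(bs=1)                        # unlabeled images from the target domain

for epoch in range(10):
    backup_modnet = copy.deepcopy(modnet)                        # frozen copy used as the consistency target
    for idx, image in enumerate(dataloader):
        soc_semantic_loss, soc_detail_loss = soc_adaptation_iter(
            modnet, backup_modnet, optimizer, image.cuda())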
@ZHKKKe I have finished all the steps. The results are good for many images when testing, but there is flicker and jitter along the matte edges in some images. Can we do something to reduce it? For example: use higher-resolution images to train MODNet, change the color space, or change some args in training or SOC?
@luoww1992 did you use a COCO-format dataset for training? I want to train with mask images but I don't know how to do it.
No, I made a new dataset to train MODNet. I don't use a COCO-format dataset because I think it is not accurate enough in the details.
@luoww1992 FYI. I trained the model on a single GPU with batch-size=8 and input-size=512. The total training time is about 2~3 days on a dataset that contains 100k samples.
@ZHKKKe would you mind providing more details on the training strategy, please? By input-size=512, do you mean that you rescale the entire image to 512x512, or do you take random 512x512 patches of it (applicable only if your image has some resolution > 512, of course)? Thanks a lot and have a great day ahead!
@FraPochetti
Hi, for your questions:
1. When you say 100k samples, do you mean 100k distinct images + related mattes?
The 100k samples are generated from 3k hand-annotated foregrounds by compositing each foreground with about 30 different backgrounds.
2. I understand from the paper that you have 3k hand-annotated images. Do you get to 100k by pasting the extracted foregrounds onto a randomly chosen set of backgrounds (as most other papers do)?
Yes, you are correct.
3. By input-size=512, do you mean that you rescale the entire image to 512x512, or do you take random 512x512 patches of it (applicable only if your image has some resolution > 512, of course)?
We composited the training set with images of size 512x512 directly. During composition, we first rescale each foreground to a random size between 384 and 768 (if the rescaled size is > 512, cropping is required). We then composite the rescaled foreground with the backgrounds (a compositing sketch follows below).
4. Do you use any specific augmentation, other than hflip, color jittering, etc.?
No. In our case, adding Gaussian noise or Gaussian blur decreases the performance.
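A rough sketch of the composition step in point 3, under stated assumptions (PIL is available and each foreground is stored as an RGBA image whose alpha channel is its matte); the authors' exact pipeline is not published here:
import random
from PIL import Image

def composite(fg_rgba, bg_rgb, out_size=512):
    # Rescale the foreground so its longer side is a random length in [384, 768],
    # crop it to out_size if it ended up larger, then paste it onto an
    # out_size x out_size background. Returns the composite and its matte.
    s = random.randint(384, 768)
    scale = s / max(fg_rgba.size)
    fg = fg_rgba.resize((round(fg_rgba.width * scale), round(fg_rgba.height * scale)), Image.BILINEAR)
    if fg.width > out_size or fg.height > out_size:
        x = random.randint(0, max(fg.width - out_size, 0))
        y = random.randint(0, max(fg.height - out_size, 0))
        fg = fg.crop((x, y, min(x + out_size, fg.width), min(y + out_size, fg.height)))
    bg = bg_rgb.resize((out_size, out_size), Image.BILINEAR)
    image = bg.copy()
    matte = Image.new('L', (out_size, out_size), 0)
    offset = ((out_size - fg.width) // 2, (out_size - fg.height) // 2)
    alpha = fg.split()[-1]
    image.paste(fg.convert('RGB'), offset, alpha)   # alpha-composite onto the background
    matte.paste(alpha, offset)                      # ground-truth matte of the composite
    return image, matte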
@ZHKKKe thanks a lot!
@luoww1992: can you share the trainDefault.txt code with SOC? Thanks a lot!
@ZHKKKe I am training it, and I get some warnings:
UserWarning: Using a target size (torch.Size([4, 3, 36, 64])) that is different to the input size (torch.Size([4, 1, 36, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic)) UserWarning: Using a target size (torch.Size([4, 3, 576, 1024])) that is different to the input size (torch.Size([4, 1, 576, 1024])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
My input images are 1024*576.
Will this cause errors in training?
Can it be used for other tasks, such as sky segmentation? I want to segment trees more accurately. I trained MODNet successfully and all three losses on my dataset are smaller than 0.02, as the picture shows. However, the inference result looks really bad.
@luoww1992 @ZHKKKe Could you please give me some advice?
@dzyjjpy For your questions: Q1: will it cause some error in training ? It seems that you use a ground truth with three channels for training. However, the outputs of the model have only one channel. Please check your ground truth.
Q2: However, the inference result looks really bad. Can you share your ground truth for training?
@dzyjjpy For your questions: Q1: It seems that you use a ground truth with three channels for training. However, the outputs of the model have only one channel. Please check your ground truth. -- The ground truth has only one channel.
Q2: However, the inference result looks really bad. Can you share your ground truth for training? -- As in the PNG file I attach.
@ZHKKKe can you explain 2 questions for me:
@ZHKKKe
Hi, thanks for your sharing. I have a question: my detail loss is still 0, although I have made the trimap values 0, 0.5, or 1, and the gt_matte values lie between [0, 1]. Which point should I check? Thanks a lot!
@dzyjjpy I guess the ground truth possibly needs to be inverted in value if you want to cut out the sky as the background; I mean the value of the sky should be zero and everything else 1. Another point is that the area outside the fisheye circle may affect the feature extraction. In addition, the backbone is set up to segment humans, so I also suggest looking at Background Matting V2 for your problem (for example, use a single color collected from the sky as the background image). These are just my opinions and may not work for your problem, but I hope they are useful. Thanks.
@ZHKKKe Thanks, I have solved it. My mistake was that I generated the trimap with 128/255, which is not equal to 0.5.
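For anyone hitting the same issue, a tiny sketch of the explicit mapping (trimap_u8 is a placeholder name for the uint8 trimap holding {0, 128, 255}); dividing by 255 gives about 0.502 for the unknown band, so (trimap > 0.5) treats it as known foreground and the detail loss collapses to zero:
import torch

trimap = torch.zeros(trimap_u8.shape, dtype=torch.float32)
trimap[trimap_u8 == 128] = 0.5   # unknown region becomes exactly 0.5
trimap[trimap_u8 == 255] = 1.0   # foreground; background stays 0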
@luoww1992
i have finished all steps, it is good in many images when testing, but there is flicker and jitter along the matte edges in some images. Can we do something to reduce it?
Hey, when I train with SOC the loss stays NaN. Could you share the SOC part of your code for reference?
Could you share the complete code that can be trained correctly? I would appreciate it. My email is 755976168@qq.com, thanks a lot!
Look up; there is a link to trainDefault.txt. Or you can search for it on the current page with 'Ctrl+F'.
@luoww1992
That's right, I found it earlier. So, would you mind telling me the image shape in your training dataset when you tried the SOC step?
@luoww1992 When you use the SOC step, which model do you update the parameters of, modnet or backup_modnet?
@upperblacksmith, I saved both models at first, then I tested them; I can't find any difference between the two models.
@upperblacksmith, my dataset size is 2048*1152; when training, the size is halved. As for the loss, maybe you have missed some steps:
Q1: You need to normalize the ground truth to [0, 1]. Please add transforms to ImagesDataset.
Q2: the detail_loss is 0: the pixel values in the loaded trimap should be 0=background, 0.5=unknown, or 1=foreground. Please pre-process it in your dataset.
You can refer to the latest comments for more information:
image (torch.autograd.Variable): input RGB image; its pixel values should be normalized
trimap (torch.autograd.Variable): trimap used to calculate the losses; its pixel values can be 0, 0.5, or 1 (foreground=1, background=0, unknown=0.5)
gt_matte (torch.autograd.Variable): ground truth alpha matte; its pixel values are between [0, 1]
@ZHKKKe What do you think of the code in the picture for the SOC step? I can't understand how to initialize modnet and backup_modnet. I would appreciate your advice.
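Going by the soc_adaptation_iter docstring earlier in this thread, the optimizer is built on modnet only, and backup_modnet is just a frozen deep copy refreshed at the start of each epoch. A minimal sketch, assuming modnet, dataloader and epochs are already set up as in the supervised stage:
import copy
import torch

optimizer = torch.optim.Adam(modnet.parameters(), lr=1e-5, betas=(0.9, 0.99))  # updates modnet only
for epoch in range(epochs):
    backup_modnet = copy.deepcopy(modnet)   # frozen snapshot, never stepped by the optimizer
    for image in dataloader:
        soc_semantic_loss, soc_detail_loss = soc_adaptation_iter(
            modnet, backup_modnet, optimizer, image.cuda())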
I am making the training dataset; it needs 3 folders -- original, trimap, matte. So must the image size be 512? And does the image need other operations, like changing the color and so on? What else should I pay attention to when making the dataset?