ZHKKKe / MODNet

A Trimap-Free Portrait Matting Solution in Real Time [AAAI 2022]
Apache License 2.0

train dataset questions #57

Closed luoww1992 closed 3 years ago

luoww1992 commented 3 years ago

I am making the training dataset. It needs 3 folders -- original, trimap, matte. Does the image size have to be 512? And does the image need other operations, like changing the color and so on? What else should I pay attention to when making the dataset?

ZHKKKe commented 3 years ago

Hi, thanks for your attention.

For your questions: Q1: it needs 3 folders -- original, trimap, matte? Yes. And the trimap can be generated from the matte.

Q2: so the size of image must be 512? No, you can use any size to train the model.

Q3: and the image need to do other operations, like change the color and so on? You can use the most common data augmentations to process the training data, e.g., flipping and normalization.
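
For reference, a minimal sketch (not the official MODNet script) of generating a trimap from an alpha matte with erosion and dilation; the kernel size is an arbitrary choice you would tune per dataset:

import cv2
import numpy as np

def matte_to_trimap(matte, kernel_size=10):
    # matte: single-channel uint8 alpha in [0, 255]
    fg = (matte > 250).astype(np.uint8)        # confident foreground
    unknown = (matte > 0).astype(np.uint8)     # any non-background pixel
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    fg_eroded = cv2.erode(fg, kernel, iterations=1)
    unknown_dilated = cv2.dilate(unknown, kernel, iterations=1)
    trimap = np.zeros_like(matte, dtype=np.uint8)   # background = 0
    trimap[unknown_dilated == 1] = 128              # unknown band = 128
    trimap[fg_eroded == 1] = 255                    # foreground = 255
    return trimap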

antithing commented 3 years ago

@luoww1992 would you be able to share your successful training approach? Did you write a dataloader and a trimap creator? Any tips or code would be greatly appreciated! Thank you!

luoww1992 commented 3 years ago

@luoww1992 would you be able to share your successful training approach? Did you write a dataloader and a trimap creator? Any tips or code would be greatly appreciated! Thank you!

I am making the trimap dataset; it will take some time.

.......

I am organizing the code.

luoww1992 commented 3 years ago

@ZHKKKe, I am training a dataset for video matting, but the loss is very large, and the detail_loss is 0 with no change. My train.py:

import math
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from functools import reduce

import cv2
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import grey_dilation, grey_erosion
from scipy.ndimage import morphology
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
from math import *

__all__ = [
    'supervised_training_iter',
    'soc_adaptation_iter',
]

def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expansion, dilation=1):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

class MobileNetV2(nn.Module):
    def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
        super(MobileNetV2, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [expansion, 24, 2, 2],
            [expansion, 32, 3, 2],
            [expansion, 64, 4, 2],
            [expansion, 96, 3, 1],
            [expansion, 160, 3, 2],
            [expansion, 320, 1, 1],
        ]

        # building first layer
        input_channel = _make_divisible(input_channel * alpha, 8)
        self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
        self.features = [conv_bn(self.in_channels, input_channel, 2)]

        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = _make_divisible(int(c * alpha), 8)
            for i in range(n):
                if i == 0:
                    self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
                else:
                    self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
                input_channel = output_channel

        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))

        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        if self.num_classes is not None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(self.last_channel, num_classes),
            )

        # Initialize weights
        self._init_weights()

    def forward(self, x, feature_names=None):
        # Stage1
        x = reduce(lambda x, n: self.features[n](x), list(range(0, 2)), x)
        # Stage2
        x = reduce(lambda x, n: self.features[n](x), list(range(2, 4)), x)
        # Stage3
        x = reduce(lambda x, n: self.features[n](x), list(range(4, 7)), x)
        # Stage4
        x = reduce(lambda x, n: self.features[n](x), list(range(7, 14)), x)
        # Stage5
        x = reduce(lambda x, n: self.features[n](x), list(range(14, 19)), x)

        # Classification
        if self.num_classes is not None:
            x = x.mean(dim=(2, 3))
            x = self.classifier(x)

        # Output
        return x

    def _load_pretrained_model(self, pretrained_file):
        pretrain_dict = torch.load(pretrained_file, map_location='cpu')
        model_dict = {}
        state_dict = self.state_dict()
        print("[MobileNetV2] Loading pretrained model...")
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
            else:
                print(k, "is ignored")
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

class IBNorm(nn.Module):
    """ Combine Instance Norm and Batch Norm into One Layer
    """

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels
        self.bnorm_channels = int(in_channels / 2)
        self.inorm_channels = in_channels - self.bnorm_channels

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)

    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())

        return torch.cat((bn_x, in_x), 1)

class Conv2dIBNormRelu(nn.Module):
    """ Convolution + IBNorm + ReLu
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 with_ibn=True, with_relu=True):
        super(Conv2dIBNormRelu, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation,
                      groups=groups, bias=bias)
        ]

        if with_ibn:
            layers.append(IBNorm(out_channels))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

class BaseBackbone(nn.Module):
    """ Superclass of Replaceable Backbone Model for Semantic Estimation
    """

    def __init__(self, in_channels):
        super(BaseBackbone, self).__init__()
        self.in_channels = in_channels

        self.model = None
        self.enc_channels = []

    def forward(self, x):
        raise NotImplementedError

    def load_pretrained_ckpt(self):
        raise NotImplementedError

class MobileNetV2Backbone(BaseBackbone):
    """ MobileNetV2 Backbone
    """

    def __init__(self, in_channels):
        super(MobileNetV2Backbone, self).__init__(in_channels)
        self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
        self.enc_channels = [16, 24, 32, 96, 1280]

    def forward(self, x):
        x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
        enc2x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
        enc4x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
        enc8x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
        enc16x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
        enc32x = x
        return [enc2x, enc4x, enc8x, enc16x, enc32x]

    def load_pretrained_ckpt(self):
        # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
        ckpt_path = './mobilenetv2_human_seg.ckpt'
        if not os.path.exists(ckpt_path):
            print('cannot find the pretrained mobilenetv2 backbone')
            exit()

        ckpt = torch.load(ckpt_path)
        self.model.load_state_dict(ckpt)

class SEBlock(nn.Module):
    """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels, out_channels, reduction=1):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels // reduction), out_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)

        return x * w.expand_as(x)

class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet
    """

    def __init__(self, backbone):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels
        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False,
                                        with_relu=False)

    def forward(self, img, inference):
        enc_features = self.backbone.forward(img)
        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]

        enc32x = self.se_block(enc32x)
        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)

        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)
        pred_semantic = None
        if not inference:
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)
        return pred_semantic, lr8x, [enc2x, enc4x]

class HRBranch(nn.Module):
    """ High Resolution Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(HRBranch, self).__init__()

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x(enc2x)
        hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))

        enc4x = self.tohr_enc4x(enc4x)
        hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))

        hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))

        pred_detail = None
        if not inference:
            hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            hr = self.conv_hr(torch.cat((hr, img), dim=1))
            pred_detail = torch.sigmoid(hr)

        return pred_detail, hr2x

class FusionBranch(nn.Module):
    """ Fusion Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)

        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
            Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x(lr4x)
        lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)

        f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
        f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
        f = self.conv_f(torch.cat((f, img), dim=1))
        pred_matte = torch.sigmoid(f)

        return pred_matte

class MODNet(nn.Module):
    """ Architecture of MODNet
    """

    def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
        super(MODNet, self).__init__()

        self.in_channels = in_channels
        self.hr_channels = hr_channels
        self.backbone_arch = backbone_arch
        self.backbone_pretrained = backbone_pretrained

        self.backbone = MobileNetV2Backbone(self.in_channels)

        self.lr_branch = LRBranch(self.backbone)
        self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
        self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()

    def forward(self, img, inference):
        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)

        pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte

    def freeze_norm(self):
        norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
        for m in self.modules():
            for n in norm_types:
                if isinstance(m, n):
                    m.eval()
                    continue

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)

class GaussianBlurLayer(nn.Module):
    """ Add Gaussian Blur to a 4D tensors
    This layer takes a 4D tensor of {N, C, H, W} as input.
    The Gaussian blur will be performed in given channel number (C) splitly.
    """

    def __init__(self, channels, kernel_size):
        """ 
        Arguments:
            channels (int): Channel for input tensor
            kernel_size (int): Size of the kernel used in blurring
        """

        super(GaussianBlurLayer, self).__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        assert self.kernel_size % 2 != 0

        self.op = nn.Sequential(
            nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
            nn.Conv2d(channels, channels, self.kernel_size,
                      stride=1, padding=0, bias=None, groups=channels)
        )

        self._init_kernel()

    def forward(self, x):
        """
        Arguments:
            x (torch.Tensor): input 4D tensor
        Returns:
            torch.Tensor: Blurred version of the input 
        """
        if not len(list(x.shape)) == 4:
            print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
            exit()
        elif not x.shape[1] == self.channels:
            print('In \'GaussianBlurLayer\', the required channel ({0}) is'
                  'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
            exit()

        return self.op(x)

    def _init_kernel(self):
        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8

        n = np.zeros((self.kernel_size, self.kernel_size))
        i = math.floor(self.kernel_size / 2)
        n[i, i] = 1
        kernel = scipy.ndimage.gaussian_filter(n, sigma)

        for name, param in self.named_parameters():
            param.data.copy_(torch.from_numpy(kernel))

class ImagesDataset(Dataset):
    def __init__(self, root, transforms=None, w=960, h=544):
        self.root = root
        self.transforms = transforms
        self.w = w
        self.h = h
        self.imgs = sorted(os.listdir(os.path.join(self.root, 'image')))
        self.alphas = sorted(os.listdir(os.path.join(self.root, 'alpha')))
        assert len(self.imgs) == len(self.alphas), 'the number of dataset is different, please check it.'

    def get_trimap(self, alpha):
        # alpha \in [0, 1] should be taken into account
        # be careful when dealing with regions of alpha=0 and alpha=1
        fg = np.array(np.equal(alpha, 255).astype(np.float32))
        unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) # unknown = alpha > 0
        unknown = unknown - fg
        # image dilation implemented by Euclidean distance transform
        unknown = morphology.distance_transform_edt(unknown==0) <= np.random.randint(1, 20)
        trimap = fg * 255
        trimap[unknown] = 128
        return trimap.astype(np.uint8)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = cv2.imread(os.path.join(self.root, 'image', self.imgs[idx]))
        alpha = cv2.imread(os.path.join(self.root, 'alpha', self.alphas[idx]))
        trimap = self.get_trimap(alpha)
        # cv2.imshow('trimap', trimap)
        # cv2.waitKey(0)
        h, w, c = img.shape
        if not (w == self.w and h == self.h):
            img = cv2.resize(img, (self.w, self.h))
            trimap = cv2.resize(trimap, (self.w, self.h))
            alpha = cv2.resize(alpha, (self.w, self.h))
        if self.transforms:
            img = self.transforms(img)
            trimap = self.transforms(trimap)
            alpha = self.transforms(alpha)
        return img, trimap, alpha

blurer = GaussianBlurLayer(3, 3).cuda()

def supervised_training_iter(
        modnet, optimizer, image, trimap, gt_matte,
        semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
    """ Supervised training iteration of MODNet
    This function trains MODNet for one iteration in a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        optimizer (torch.optim.Optimizer): optimizer for supervised training 
        image (torch.autograd.Variable): input RGB image
        trimap (torch.autograd.Variable): trimap used to calculate the losses
                                          NOTE: foreground=1, background=0, unknown=0.5
        gt_matte (torch.autograd.Variable): ground truth alpha matte
        semantic_scale (float): scale of the semantic loss
                                NOTE: please adjust according to your dataset
        detail_scale (float): scale of the detail loss
                              NOTE: please adjust according to your dataset
        matte_scale (float): scale of the matte loss
                             NOTE: please adjust according to your dataset

    Returns:
        semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
        detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
        matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]

    Example:
        import torch
        from src.models.modnet import MODNet
        from src.trainer import supervised_training_iter

        bs = 16         # batch size
        lr = 0.01       # learn rate
        epochs = 40     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            for idx, (image, trimap, gt_matte) in enumerate(dataloader):
                semantic_loss, detail_loss, matte_loss = \
                    supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            lr_scheduler.step()
    """

    global blurer

    # set the model to train mode and clear the optimizer
    modnet.train()
    optimizer.zero_grad()

    # forward the model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # calculate the boundary mask from the trimap
    boundaries = (trimap < 0.5) + (trimap > 0.5)

    # calculate the semantic loss
    gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
    gt_semantic = blurer(gt_semantic)
    semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
    semantic_loss = semantic_scale * semantic_loss

    # calculate the detail loss
    pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
    gt_detail = torch.where(boundaries, trimap, gt_matte)
    detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
    detail_loss = detail_scale * detail_loss

    # calculate the matte loss
    pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
    matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
    matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
                               + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
    matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
    matte_loss = matte_scale * matte_loss

    # calculate the final loss, backward the loss, and update the model 
    loss = semantic_loss + detail_loss + matte_loss
    loss.backward()
    optimizer.step()

    # for test
    return semantic_loss, detail_loss, matte_loss

def soc_adaptation_iter(
        modnet, backup_modnet, optimizer, image,
        soc_semantic_scale=100.0, soc_detail_scale=1.0):
    """ Self-Supervised sub-objective consistency (SOC) adaptation iteration of MODNet
    This function fine-tunes MODNet for one iteration in an unlabeled dataset.
    Note that SOC can only fine-tune a converged MODNet, i.e., MODNet that has been 
    trained in a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        backup_modnet (torch.nn.Module): backup of the trained MODNet
        optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC 
        image (torch.autograd.Variable): input RGB image
        soc_semantic_scale (float): scale of the SOC semantic loss 
                                    NOTE: please adjust according to your dataset
        soc_detail_scale (float): scale of the SOC detail loss
                                  NOTE: please adjust according to your dataset

    Returns:
        soc_semantic_loss (torch.Tensor): loss of the semantic SOC
        soc_detail_loss (torch.Tensor): loss of the detail SOC

    Example:
        import copy
        import torch
        from src.models.modnet import MODNet
        from src.trainer import soc_adaptation_iter

        bs = 1          # batch size
        lr = 0.00001    # learn rate
        epochs = 10     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        modnet = LOAD_TRAINED_CKPT()    # NOTE: please finish this function

        optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            backup_modnet = copy.deepcopy(modnet)
            for idx, (image) in enumerate(dataloader):
                soc_semantic_loss, soc_detail_loss = \
                    soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
    """

    global blurer

    # set the backup model to eval mode
    backup_modnet.eval()

    # set the main model to train mode and freeze its norm layers
    modnet.train()
    modnet.module.freeze_norm()

    # clear the optimizer
    optimizer.zero_grad()

    # forward the main model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # forward the backup model
    with torch.no_grad():
        _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)

    # calculate the boundary mask from `pred_matte` and `pred_semantic`
    pred_matte_fg = (pred_matte.detach() > 0.1).float()
    pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
    pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
    pred_fg = pred_matte_fg * pred_semantic_fg

    n, c, h, w = pred_matte.shape
    np_pred_fg = pred_fg.data.cpu().numpy()
    np_boundaries = np.zeros([n, c, h, w])
    for sdx in range(0, n):
        sample_np_boundaries = np_boundaries[sdx, 0, ...]
        sample_np_pred_fg = np_pred_fg[sdx, 0, ...]

        side = int((h + w) / 2 * 0.05)
        dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
        eroded = grey_erosion(sample_np_pred_fg, size=(side, side))

        sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
        np_boundaries[sdx, 0, ...] = sample_np_boundaries

    boundaries = torch.tensor(np_boundaries).float().cuda()

    # sub-objectives consistency between `pred_semantic` and `pred_matte`
    # generate pseudo ground truth for `pred_semantic`
    downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1 / 16, mode='bilinear'))
    pseudo_gt_semantic = downsampled_pred_matte.detach()
    pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()

    # generate pseudo ground truth for `pred_matte`
    pseudo_gt_matte = pred_semantic.detach()
    pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()

    # calculate the SOC semantic loss
    soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte,
                                                                                   pseudo_gt_matte)
    soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)

    # NOTE: using the formulas in our paper to calculate the following losses has similar results
    # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
    backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail)
    backup_detail_loss = torch.sum(backup_detail_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_detail_loss = torch.mean(backup_detail_loss)

    # sub-objectives consistency between pred_matte` and `pred_backup_matte` (on boundaries only)
    backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte)
    backup_matte_loss = torch.sum(backup_matte_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_matte_loss = torch.mean(backup_matte_loss)

    soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)

    # calculate the final loss, backward the loss, and update the model 
    loss = soc_semantic_loss + soc_detail_loss

    loss.backward()
    optimizer.step()

    return soc_semantic_loss, soc_detail_loss

# ----------------------------------------------------------------------------------

def main(root):
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)
    GPU = True if torch.cuda.device_count() > 0 else False
    if GPU:
        print('Use GPU...')
        modnet = modnet.cuda()
        modnet.load_state_dict(torch.load(pretrained_ckpt))
    else:
        print('Use CPU...')
        modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu')))
    modnet.eval()
    bs = 1  # batch size
    lr = 0.01  # learn rate
    epochs = 40  # total epochs
    num_workers = 8
    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

    # dataloader = CREATE_YOUR_DATALOADER(bs)  # NOTE: please finish this function
    dataset = ImagesDataset(root)
    dataloader = DataLoader(dataset, batch_size=bs, num_workers=num_workers, pin_memory=True)

    for epoch in range(epochs):
        for idx, (image, trimap, gt_matte) in enumerate(dataloader):
            image = np.transpose(image, (0, 3, 1, 2)).float().cuda()
            trimap = np.transpose(trimap, (0, 3, 1, 2)).float().cuda()
            gt_matte = np.transpose(gt_matte, (0, 3, 1, 2)).float().cuda()
            semantic_loss, detail_loss, matte_loss = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            print(f"epoch: {epoch+1}/{epochs} semantic_loss: {semantic_loss}, detail_loss: {detail_loss}, matte_loss: {matte_loss}")
        lr_scheduler.step()

if __name__ == '__main__':
    path = 'MODNet/dataset'
    main(path)

The loss output shows:

epoch: 21/40 semantic_loss: 267474.09375, detail_loss: 0.0, matte_loss: 11912.3251953125
epoch: 21/40 semantic_loss: 152139.375, detail_loss: 0.0, matte_loss: 9059.408203125
epoch: 21/40 semantic_loss: 314421.375, detail_loss: 0.0, matte_loss: 10369.087890625

My GPU is a 2080 Ti with 11 GB.

Where did something go wrong?

luoww1992 commented 3 years ago

@ZHKKKe,
is the size of my video matting training images too large, or am I missing some steps when making the dataset?

ZHKKKe commented 3 years ago

@luoww1992 Hi, thanks for your attention.

For your questions:

Q1: the loss is large. You need to normalize the ground truth to [0, 1]. Please add transforms to ImagesDataset.

Q2: the detail_loss is 0. The pixel values in the loaded trimap should be 0=background, 0.5=unknown, or 1=foreground. Please pre-process it in your dataset.

You can refer to the latest comments for more information:

        image (torch.autograd.Variable): input RGB image
                                         its pixel values should be normalized
        trimap (torch.autograd.Variable): trimap used to calculate the losses
                                          its pixel values can be 0, 0.5, or 1
                                          (foreground=1, background=0, unknown=0.5)
        gt_matte (torch.autograd.Variable): ground truth alpha matte
                                            its pixel values are between [0, 1]
luoww1992 commented 3 years ago

@ZHKKKe I am training it, and I get some warnings:

UserWarning: Using a target size (torch.Size([4, 3, 36, 64])) that is different to the input size (torch.Size([4, 1, 36, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
UserWarning: Using a target size (torch.Size([4, 3, 576, 1024])) that is different to the input size (torch.Size([4, 1, 576, 1024])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)

My input images are 1024*576. Will this cause an error in training?

ZHKKKe commented 3 years ago

Please change the channel of gt_matte to 1.
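
For example, a minimal sketch (alpha_path is a placeholder for however your dataset locates the matte file):

import cv2

def load_matte(alpha_path):
    # Read the matte with a single channel so gt_matte ends up (1, H, W), not (3, H, W).
    alpha = cv2.imread(alpha_path, cv2.IMREAD_GRAYSCALE)
    # Or, if it was already loaded with 3 channels:
    # alpha = cv2.cvtColor(alpha, cv2.COLOR_BGR2GRAY)
    return alpha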

Johnson-yue commented 3 years ago

@luoww1992 Did you manage to train it successfully?

Shk-aftab commented 3 years ago

@luoww1992 were you able to train it?

luoww1992 commented 3 years ago

I have made the training dataset with 80k images and I am training it. I have run it for 4 days, but it is too slow.

Shk-aftab commented 3 years ago

@luoww1992 that's great ... can you share the dataset?

luoww1992 commented 3 years ago

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization; there are some personal images. You can try to make it from alpha images.

czHappy commented 3 years ago

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization; there are some personal images. You can try to make it from alpha images.

Could you share the complete code that can be trained correctly? And I would appreciate it if you could send me a few training samples(10-20 is enough). My email is wyking9@163.com, thanks a lot!

Shk-aftab commented 3 years ago

Hey @luoww1992, please share the complete trainable code.

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization; there are some personal images. You can try to make it from alpha images.

Could you share the complete code that can be trained correctly? And I would appreciate it if you could send me a few training samples(10-20 is enough). My email is wyking9@163.com, thanks a lot!

@luoww1992 for me too: shaikhaftab139@gmail.com, thanks.

luoww1992 commented 3 years ago

trainDefault.txt

I have updated it.

Please pay attention to the notes in the functions.

andy910389 commented 3 years ago

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization; there are some personal images. You can try to make it from alpha images.

Please share the complete code that can be trained correctly with me, many thanks! Mail: andy910389@gmail.com

luoww1992 commented 3 years ago

I have added it in trainDefault.txt.

czHappy commented 3 years ago

I have added it in trainDefault.txt.

Thank you very much! I will read the code carefully; I hope it will work!

luoww1992 commented 3 years ago

Now I am running two types of models: 1. the default training model -- I use the same size and the model the author provided; 2. my own training model -- I change the size and train a new model from zero, which is very slow. I am waiting to finish training and then start the SOC step.

ZHKKKe commented 3 years ago

@luoww1992 FYI. I trained the model on a Single GPU with batch-size=8 and input-size=512. The total training time is about 2~3 days on a dataset that contains 100k samples.

luoww1992 commented 3 years ago

@luoww1992 FYI. I trained the model on a single GPU with batch-size=8 and input-size=512. The total training time is about 2~3 days on a dataset that contains 100k samples.

I am using batch-size=4, an 11 GB GPU, and 80k images, so it will take more time. When you finished training, how far did the matte_loss go down?

ZHKKKe commented 3 years ago

@luoww1992 I got the average training matte_loss=0.0364 in the last training epoch. However, the loss value should be different depending on the dataset.

luoww1992 commented 3 years ago

@ZHKKKe I will run the SOC step. I notice we only need images to make the dataset, without alpha images. Don't we need a standard to show us whether the model is better when there are no alpha images to compare against? How many images did you use to run SOC?

ZHKKKe commented 3 years ago

@luoww1992 Since the dataset for SOC is unlabeled, we should use visual comparison or a user study to see which model is better. You can use many video clips (we use 40k frames for our WebCam model) to adapt the model to a specific data domain, or you can use only one video clip to adapt the model to a specific video.
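
If it helps, here is a small sketch of dumping frames from a clip to build the unlabeled SOC set (the paths and sampling stride are placeholders, not part of the repository):

import os
import cv2

def dump_frames(video_path, out_dir, every_n=5):
    # Save every n-th frame of a clip as a JPEG for the unlabeled SOC dataset.
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx, saved = 0, 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % every_n == 0:
            cv2.imwrite(os.path.join(out_dir, '%06d.jpg' % saved), frame)
            saved += 1
        idx += 1
    cap.release()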

luoww1992 commented 3 years ago

@ZHKKKe when running SOC, I can make two datasets for it: 1 -- choose 40k images from the training dataset; 2 -- make a new dataset for it. Are there any differences between 1 and 2?

ZHKKKe commented 3 years ago

@luoww1992 SOC is a technique for self-supervised domain adaptation. Its goal is to generalize the trained MODNet to a new domain without labeled data. For your options: 1 -- choose 40k images from the training dataset: SOC will not contribute to the performance; please train MODNet with the labels directly. 2 -- make a new dataset for it: if the data in the new unlabeled dataset is from the same domain as the labeled training dataset, SOC is unnecessary.

In practice, we usually train MODNet on a dataset with background replacement. Therefore, the model may perform badly on natural images without background replacement. SOC should improve the trained model if you fine-tune it on unlabeled natural images without background replacement.

luoww1992 commented 3 years ago

@ZHKKKe I have finished all the steps. It is good on many images when testing, but there is flicker and jitter at the matting edges in some images. Can we do something to reduce it? Such as: use higher-resolution images to train MODNet, change the color space, or change some args in training or SOC.

Johnson-yue commented 3 years ago

@luoww1992 did you use COCO format for training? I want to train with mask images but I don't know how to do it.

luoww1992 commented 3 years ago

@luoww1992 did you use COCO format for training? I want to train with mask images but I don't know how to do it.

No, I made a new dataset to train MODNet. I don't use a COCO-format dataset for training because I think it is not accurate enough in the details.

FraPochetti commented 3 years ago

@luoww1992 FYI. I trained the model on a Single GPU with batch-size=8 and input-size=512. The total training time is about 2~3 days on a dataset that contains 100k samples.

@ZHKKKe would you mind providing more details on the training strategy, please?

  1. when you say 100k samples, do you mean 100k distinct images + related mattes?
  2. I understand from the paper that you have 3k hand-annotated images. Do you get to 100k by pasting the extracted foregrounds onto a randomly chosen set of backgrounds (as most other papers do)?
  3. by input-size=512, do you mean that you rescale the entire image to 512x512 or do you take random 512x512 patches of it (applicable only if your image has some resolution > 512 of course)?
  4. do you use any specific augmentation, other than hflip, color jittering, etc?

Thanks a lot and have a great day ahead!

ZHKKKe commented 3 years ago

@FraPochetti Hi, for your questions:

  1. when you say 100k samples, do you mean 100k distinct images + related mattes? The 100k samples are generated from 3k hand-annotated foregrounds by compositing each foreground with about 30 different backgrounds.
  2. I understand from the paper that you have 3k hand-annotated images. Do you get to 100k by pasting the extracted foregrounds onto a randomly chosen set of backgrounds (as most other papers do)? Yes, you are correct.
  3. by input-size=512, do you mean that you rescale the entire image to 512x512 or do you take random 512x512 patches of it (applicable only if your image has some resolution > 512, of course)? We composited the training set with images of size 512x512 directly. During composition, we first rescale each foreground to a random size between 384~768 (if the rescaled size > 512, cropping is required). We then composite the rescaled foreground with the backgrounds. A rough sketch of this compositing procedure is given below.
  4. do you use any specific augmentation, other than hflip, color jittering, etc? Nope. In our case, adding Gaussian noise or Gaussian blur decreased the performance.
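
For illustration only, a rough sketch of that compositing step under the assumptions above (random rescale to 384~768, crop or pad back to 512, alpha blending onto a resized background); the padding choice is arbitrary:

import cv2
import numpy as np

def composite(fg_bgr, alpha_u8, bg_bgr, out_size=512):
    # Rescale the foreground (and its single-channel matte) to a random size in [384, 768].
    scale = np.random.randint(384, 769)
    fg = cv2.resize(fg_bgr, (scale, scale))
    a = cv2.resize(alpha_u8, (scale, scale)).astype(np.float32) / 255.0
    if scale > out_size:
        # Random crop back to out_size x out_size.
        y = np.random.randint(0, scale - out_size + 1)
        x = np.random.randint(0, scale - out_size + 1)
        fg = fg[y:y + out_size, x:x + out_size]
        a = a[y:y + out_size, x:x + out_size]
    elif scale < out_size:
        # Pad with zeros up to out_size x out_size.
        pad = out_size - scale
        fg = cv2.copyMakeBorder(fg, 0, pad, 0, pad, cv2.BORDER_CONSTANT, value=0)
        a = cv2.copyMakeBorder(a, 0, pad, 0, pad, cv2.BORDER_CONSTANT, value=0)
    bg = cv2.resize(bg_bgr, (out_size, out_size)).astype(np.float32)
    # Alpha-blend the foreground onto the background.
    comp = a[..., None] * fg.astype(np.float32) + (1.0 - a[..., None]) * bg
    return comp.astype(np.uint8), (a * 255).astype(np.uint8)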

FraPochetti commented 3 years ago

@ZHKKKe thanks a lot!

twin-92 commented 3 years ago

@luoww1992: can you share the trainDefault.txt code with SOC? Thanks a lot!

dzyjjpy commented 3 years ago

@ZHKKKe I am training it, and I get some warnings:

UserWarning: Using a target size (torch.Size([4, 3, 36, 64])) that is different to the input size (torch.Size([4, 1, 36, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
UserWarning: Using a target size (torch.Size([4, 3, 576, 1024])) that is different to the input size (torch.Size([4, 1, 576, 1024])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)

My input images are 1024*576.

Will this cause an error in training?

Can it be used for other tasks, such as sky segmentation? I want to segment trees more accurately. I trained MODNet successfully and all three losses on the dataset are smaller than 0.02, as the picture shows. However, the inference result looks really bad.

@luoww1992 @ZHKKKe Could you please give me some advice?

ZHKKKe commented 3 years ago

@dzyjjpy For your questions: Q1: will it cause some error in training ? It seems that you use a ground truth with three channels for training. However, the outputs of the model have only one channel. Please check your ground truth.

Q2: However, the inference result looks really bad. Can you share your ground truth for training?

dzyjjpy commented 3 years ago

@dzyjjpy For your questions: Q1: will it cause some error in training? It seems that you use a ground truth with three channels for training. However, the outputs of the model have only one channel. Please check your ground truth.

The ground truth has only one channel.

Q2: However, the inference result looks really bad. Can you share your ground truth for training?

It is the PNG file I attach (0007).

twin-92 commented 3 years ago

@ZHKKKe can you explain 2 questions for me:

  1. why do you use an image input with size 512 where the larger side is divisible by 32, rather than an input size of 512x512? (I don't see a resize to 512x512 in your inference.py.)
  2. and you create the alpha matting dataset with Photoshop. Do you create a binary mask of the person in the image (only 2 values: 0 and 255), then save this mask, and that is the labeled alpha matte of this image? Thank you very much!
Boya-Na commented 3 years ago

@ZHKKKe

Hi, thanks for your sharing. I have a question: my detail loss is still 0 although I have made the trimap values 0, 0.5, or 1, and the gt_matte values between [0, 1]. Which point should I check? Thanks a lot!

Boya-Na commented 3 years ago

@dzyjjpy I guess the ground truth possibly needs to be inverted if you want to cut out the sky as the background; I mean the value of the sky should be zero and everything else should be 1. Another point is that the area outside the fisheye circle may affect feature extraction. In addition, the backbone is set up to segment humans, and I also suggest looking at Background Matting v2 for your problem (for example, use a single color collected from the sky as the background image). These are just my opinions and may not work for your problem, but I hope they can be useful for you. Thanks.

Boya-Na commented 3 years ago

@ZHKKKe Thanks, I have solved it. It was my error: I generated the trimap with 128/255, which is not equal to 0.5.
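
In case anyone else hits the same thing, here is a small sketch that maps a {0, 128, 255} trimap to exactly {0, 0.5, 1} instead of dividing by 255 (trimap_u8 is a placeholder name for the loaded uint8 trimap):

import numpy as np

# 128 / 255 != 0.5, so the boundary mask (trimap < 0.5) + (trimap > 0.5) covers every pixel
# and the detail loss sees no unknown region; map the values explicitly instead.
trimap = np.zeros(trimap_u8.shape, dtype=np.float32)
trimap[trimap_u8 == 128] = 0.5
trimap[trimap_u8 == 255] = 1.0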

upperblacksmith commented 3 years ago

@luoww1992

I have finished all the steps. It is good on many images when testing, but there is flicker and jitter at the matting edges in some images. Can we do something to reduce it? Such as: use higher-resolution images to train MODNet, change the color space, or change some args in training or SOC.

Bro, when I train with SOC, the loss is always NaN. Could I take a look at your SOC code for reference?

upperblacksmith commented 3 years ago

I have added it in trainDefault.txt.

Could you share the complete code that can be trained correctly? I would appreciate it. My email is 755976168@qq.com, thanks a lot!

luoww1992 commented 3 years ago

Look up; there is a link to trainDefault.txt. Or you can search for it on the current page with Ctrl+F.

upperblacksmith commented 3 years ago

@luoww1992

Look up; there is a link to trainDefault.txt. Or you can search for it on the current page with Ctrl+F.

That's right, I found it earlier. So, would you mind telling me the image shape in your training dataset when you tried the SOC step?

upperblacksmith commented 3 years ago

@luoww1992 when you use the SOC step, which model's parameters do you update, modnet or backup_modnet?

luoww1992 commented 3 years ago

@upperblacksmith, I saved both models at first, then I tested them; I cannot find any difference between the two models.

luoww1992 commented 3 years ago

@upperblacksmith, my dataset size is 2048*1152; when training, the size is halved. As for the loss, maybe you missed some steps:

Q1: You need to normalize the ground truth to [0, 1]. Please add transforms to ImagesDataset.

Q2: the detail_loss is 0. The pixel values in the loaded trimap should be 0=background, 0.5=unknown, or 1=foreground. Please pre-process it in your dataset.

You can refer to the latest comments for more information:

    image (torch.autograd.Variable): input RGB image
                                     its pixel values should be normalized
    trimap (torch.autograd.Variable): trimap used to calculate the losses
                                      its pixel values can be 0, 0.5, or 1
                                      (foreground=1, background=0, unknown=0.5)
    gt_matte (torch.autograd.Variable): ground truth alpha matte
                                        its pixel values are between [0, 1]
luoww1992 commented 3 years ago

@upperblacksmith, add this code to the training code:

run.txt

upperblacksmith commented 3 years ago

@ZHKKKe what do you think of the code in the picture when using the SOC step? I can't understand how to initialize modnet and backup_modnet. I would appreciate your advice.