QLYoo / LFPNet

GNU General Public License v3.0

About training code #6

Open EricLe-dev opened 3 years ago

EricLe-dev commented 3 years ago

Fantastic paper with interesting results, I would say. I really want to test the learning capability of this model. Do you plan to release the training code? Thank you so much.

Windaway commented 3 years ago

Sorry, the training code of LFPNet v1 is dirty, and we do not plan to release the code.

EricLe-dev commented 3 years ago

@Windaway Thank you so much for the quick reply. I'm trying to reproduce your results. Can you please share with me the code of the loss functions that you used? Thank you so much.

Windaway commented 3 years ago

The Laplacian loss used in the loss function is as follows. The other cost terms can be obtained easily.

from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F

def gaussian(window_size, sigma):
    x = torch.arange(window_size).float() - window_size // 2
    if window_size % 2 == 0:
        x = x + 0.5
    gauss = torch.exp((-x.pow(2.0) / float(2 * sigma ** 2)))
    return gauss / gauss.sum()

def get_gaussian_kernel1d(kernel_size: int,
                          sigma: float,
                          force_even: bool = False) -> torch.Tensor:
    if (not isinstance(kernel_size, int) or (
            (kernel_size % 2 == 0) and not force_even) or (
            kernel_size <= 0)):
        raise TypeError(
            "kernel_size must be an odd positive integer. "
            "Got {}".format(kernel_size)
        )
    window_1d: torch.Tensor = gaussian(kernel_size, sigma)
    return window_1d

def get_gaussian_kernel2d(
        kernel_size: Tuple[int, int],
        sigma: Tuple[float, float],
        force_even: bool = False) -> torch.Tensor:
    if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
        raise TypeError(
            "kernel_size must be a tuple of length two. Got {}".format(
                kernel_size
            )
        )
    if not isinstance(sigma, tuple) or len(sigma) != 2:
        raise TypeError(
            "sigma must be a tuple of length two. Got {}".format(sigma)
        )
    ksize_x, ksize_y = kernel_size
    sigma_x, sigma_y = sigma
    kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even)
    kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even)
    kernel_2d: torch.Tensor = torch.matmul(
        kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t()
    )
    return kernel_2d

def compute_padding(kernel_size: Tuple[int, int]) -> List[int]:
    assert len(kernel_size) == 2, kernel_size
    computed = [k // 2 for k in kernel_size]

    return [computed[1] - 1 if kernel_size[0] % 2 == 0 else computed[1],
            computed[1],
            computed[0] - 1 if kernel_size[1] % 2 == 0 else computed[0],
            computed[0]]

def filter2D(input: torch.Tensor, kernel: torch.Tensor,
             border_type: str = 'reflect') -> torch.Tensor:
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not isinstance(kernel, torch.Tensor):
        raise TypeError("Input kernel type is not a torch.Tensor. Got {}"
                        .format(type(kernel)))

    if not isinstance(border_type, str):
        raise TypeError("Input border_type is not a string. Got {}"
                        .format(type(border_type)))

    if not len(input.shape) == 4:
        raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
                         .format(input.shape))

    if not len(kernel.shape) == 3:
        raise ValueError("Invalid kernel shape, we expect 1xHxW. Got: {}"
                         .format(kernel.shape))

    borders_list: List[str] = ['constant', 'reflect', 'replicate', 'circular']
    if border_type not in borders_list:
        raise ValueError("Invalid border_type, we expect the following: {0}."
                         "Got: {1}".format(borders_list, border_type))

    b, c, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.unsqueeze(0).to(input.device).to(input.dtype)
    height, width = tmp_kernel.shape[-2:]
    padding_shape: List[int] = compute_padding((height, width))
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)
    b, c, hp, wp = input_pad.shape
    kernel_numel: int = height * width
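    # Large kernels: fold the channels into the batch dimension and convolve with a single-channel
    # kernel; small kernels: use a grouped (depthwise) convolution over the original batch.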
    if kernel_numel > 81:
        return F.conv2d(input_pad.reshape(b * c, 1, hp, wp), tmp_kernel, padding=0, stride=1).view(b, c, h, w)
    return F.conv2d(input_pad, tmp_kernel.expand(c, -1, -1, -1), groups=c, padding=0, stride=1)

def make_lp(img, kernel, max_levels, pad_type):
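    # Laplacian pyramid: each level stores the residual between the current image and its
    # Gaussian-blurred version; the blurred image is average-pooled by 2 for the next level,
    # and the remaining low-pass image is appended last.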
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = filter2D(current, kernel, pad_type)
        diff = current - filtered
        pyr.append(diff)
        current = torch.nn.functional.avg_pool2d(filtered, 2)
    pyr.append(current)
    return pyr

class inverse_huber_loss(nn.Module):
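    # Inverse Huber (berHu) loss: L1 for residuals below the threshold C, quadratic above it.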
    def __init__(self):
        super(inverse_huber_loss, self).__init__()

    def forward(self, input, target):
        absdiff = torch.abs(input - target)
        C = 0.2 * torch.max(absdiff).item()
        return torch.mean(torch.where(absdiff < C, absdiff, (absdiff * absdiff + C * C) / (2 * C)))

class lap_loss(nn.Module):
    def __init__(self, max_levels=5, k_size=(5, 5), sigma=(1.5, 1.5), board_type='reflect', loss_type='L1',
                 loss_multiplier=2,clip=True,clipmin=0.,clipmax=1.,reduction='mean'):
        super(lap_loss, self).__init__()
        self.max_levels = max_levels
        self.k_size = k_size
        self.sigma = sigma
        self.board_type = board_type
        self._gauss_kernel = torch.unsqueeze(get_gaussian_kernel2d(k_size, sigma), dim=0)
        self.clip=clip
        self.clipmin = clipmin
        self.clipmax = clipmax
        loss_list: List[str] = ['L1', 'L2','IHuber']
        self.loss_multiplier = loss_multiplier
        if loss_type not in loss_list:
            raise ValueError("Invalid loss_type, we expect the following: {0}."
                             "Got: {1}".format(loss_list, loss_type))
        self.loss_type = loss_type
        if self.loss_type == 'L1':
            self.loss = nn.L1Loss(reduction=reduction)
        elif self.loss_type == 'L2':
            self.loss = nn.MSELoss(reduction=reduction)
        elif self.loss_type=='IHuber':
            self.loss = inverse_huber_loss()

    def forward(self, input, target):
        if self.clip:
            input=torch.clamp(input,self.clipmin,self.clipmax)
        pyr_input = make_lp(input, self._gauss_kernel, self.max_levels, self.board_type)
        pyr_target = make_lp(target, self._gauss_kernel, self.max_levels, self.board_type)
        losses = []
        mul = 1
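        # Compare only the band-pass levels (the final low-pass image in the pyramid is skipped);
        # coarser levels get progressively larger weights via loss_multiplier.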
        for x in range(self.max_levels):
            losses.append(mul * self.loss(pyr_input[x], pyr_target[x]))
            mul *= self.loss_multiplier
        return sum(losses)

class grad_loss(nn.Module):
    def __init__(self, loss_type='L1',clip=False,clipmin=0.,clipmax=1.,reduction='mean'):
        super(grad_loss, self).__init__()
        loss_list: List[str] = ['L1', 'L2']
        if loss_type not in loss_list:
            raise ValueError("Invalid loss_type, we expect the following: {0}."
                             "Got: {1}".format(loss_list, loss_type))
        self.loss_type = loss_type
        if self.loss_type == 'L1':
            self.loss = nn.L1Loss(reduction=reduction)
        elif self.loss_type == 'L2':
            self.loss = nn.MSELoss(reduction=reduction)
        self.clip=clip
        self.clipmin = clipmin
        self.clipmax = clipmax

    def forward(self, input, target):
        if self.clip:
            input=torch.clamp(input,self.clipmin,self.clipmax)
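        # Finite-difference gradients along height and width; the loss compares the gradients
        # of the prediction with those of the target.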
        inputx_=input[:,:,0:-1,:]-input[:,:,1:,:]
        inputy_=input[:,:,:,0:-1]-input[:,:,:,1:]
        targetx_=target[:,:,0:-1,:]-target[:,:,1:,:]
        targety_=target[:,:,:,0:-1]-target[:,:,:,1:]
        loss=self.loss(inputx_,targetx_)+self.loss(inputy_,targety_)
        return loss

if __name__ == '__main__':
    l = lap_loss().cuda()
    a = torch.randn(5, 1, 128, 128).cuda()
    b = torch.randn(5, 1, 128, 128).cuda()
    c = l(a, a/4)
    print(c)
    c = l(a, a/2)
    print(c)
    c = l(a, a*3/4)
    print(c)

EricLe-dev commented 3 years ago

@Windaway Fantastic help, much appreciated. I have one more quick question and hope you will share your experience. I'm digging deeper into your paper. In it, you state that:

We use the propagating loss and matting loss to train the propagating module and the matting module, respectively.

Did you train these two modules separately? If so, how did you initialize them during training? This is a bit unclear to me. Also, you said that:

the training code of LFPNet v1 is dirty

To be honest, your code is much easier to understand than most other code :) If you don't mind, I would really appreciate it if you could share a snippet for this, either here or via my email at dungtuan.le@outlook.com.

Thank you so much.

Windaway commented 3 years ago

In fact, we train different parts of the network in four stages (including the pre-training stage). The outputs of the propagating module and the matting module are supervised by the propagating loss and the matting loss, respectively.

Pre-training stage: we use enlarged images to pretrain the network.
Stage 1: we use regular images to train the matting module and the feature transform layer of the propagating module (propagating loss + matting loss).
Stage 2: we use regular images to train the matting module and the whole decoder of the propagating module (propagating loss + matting loss).
Stage 3: we use regular images to fine-tune the whole network (propagating loss + matting loss).

Note that we deleted the context alpha matte output in model.py to make the code look a bit more compact.
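
A minimal sketch of how this stage-wise schedule could be wired up, assuming hypothetical submodule names (matting_module, propagating_module.feature_transform, propagating_module.decoder) that are not the actual LFPNet attribute names:

import itertools

import torch

def build_stage_optimizer(model, stage, lr=1e-4):
    # Select which parameters the optimizer updates in each stage; the rest are left untouched.
    if stage == 1:
        # Stage 1: matting module + feature transform layer of the propagating module.
        params = itertools.chain(model.matting_module.parameters(),
                                 model.propagating_module.feature_transform.parameters())
    elif stage == 2:
        # Stage 2: matting module + the whole decoder of the propagating module.
        params = itertools.chain(model.matting_module.parameters(),
                                 model.propagating_module.decoder.parameters())
    else:
        # Stage 3 (and pre-training): the whole network.
        params = model.parameters()
    return torch.optim.Adam(params, lr=lr)

# In every stage the objective is the sum of the two supervised terms:
#   total_loss = propagating_loss(prop_out, prop_gt) + matting_loss(matte_out, matte_gt)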

VolodymyrAhafonov commented 3 years ago

@Windaway Hi, I have a question about the training procedure.

The images are randomly cropped to patches of dimensions 768×768, 640×640, 512×512, 448×448, 320×320.

Do you resize these crops to one uniform resolution before passing them to the network, like the FBA authors did? They also take different-sized crops and then resize them all to 320x320 before passing them to the network.

Windaway commented 3 years ago

@VolodymyrAhafonov No, we first perform an affine transformation. Then we crop image patches without resizing.
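
For what it's worth, a minimal sketch of that order of operations (random affine first, then a random crop at one of the listed patch sizes, with no resizing afterwards); the helper below is an illustration, not the authors' actual pipeline:

import random

import torch

CROP_SIZES = [768, 640, 512, 448, 320]  # patch sizes quoted above

def random_patch(image: torch.Tensor, alpha: torch.Tensor):
    # image: CxHxW, alpha: 1xHxW, both already affine-transformed.
    size = random.choice(CROP_SIZES)
    _, h, w = image.shape
    size = min(size, h, w)  # fall back if the image is smaller than the chosen patch
    top = random.randint(0, h - size)
    left = random.randint(0, w - size)
    # Crop both tensors at the same location, without any resizing.
    return (image[:, top:top + size, left:left + size],
            alpha[:, top:top + size, left:left + size])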

VolodymyrAhafonov commented 3 years ago

@Windaway, thank you for your clarification! I have one more question about the training procedure. What do you mean by this quote?

we use enlarge images to pretrain the network.

Are you just upsampling the input images, or using some other method to enlarge them? Thanks in advance. I really appreciate your openness to questions!

Windaway commented 3 years ago

Yes, we upsample the image with bicubic interpolation, and I remember the scale is 2X.
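
In PyTorch terms, that pre-training enlargement presumably amounts to something like this (the scale and interpolation flags follow this thread, not the released code):

import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 320, 320)  # BxCxHxW
img_2x = F.interpolate(img, scale_factor=2, mode='bicubic', align_corners=False)
print(img_2x.shape)  # torch.Size([1, 3, 640, 640])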

VolodymyrAhafonov commented 3 years ago

@Windaway, thank you for reply!

mexicantexan commented 1 year ago

@EricLe-dev or @VolodymyrAhafonov could you share your code/results from your reproduction?