JuewenPeng / BokehMe

BokehMe: When Neural Rendering Meets Classical Rendering (CVPR 2022 Oral)
Apache License 2.0

code of training model #4

Closed: JialeHu97 closed this issue 2 years ago

JialeHu97 commented 2 years ago

@JuewenPeng Hi, could you provide the code for training the model? I would be very grateful!

JuewenPeng commented 2 years ago

I'm sorry, but we do not plan to release the training code for now.

JialeHu97 commented 2 years ago

That's okay, thank you.

JuewenPeng commented 2 years ago

Thanks for your understanding.

JialeHu97 commented 2 years ago

@JuewenPeng Hello, fellow HUST alumnus! Could you provide the training dataset? I would be very grateful!

JuewenPeng commented 2 years ago

I am uploading the data to Baidu Netdisk. It will take some time. BTW, you can refer to our other paper (MPIB: An MPI-Based Bokeh Rendering Framework for Realistic Partial Occlusion Effects) for the details of data generation.

JialeHu97 commented 2 years ago

Ok, thank you!

JialeHu97 commented 2 years ago

@JuewenPeng Hello, could you provide the values of K and disp_focus used in the BLB dataset? I would be very grateful!

JuewenPeng commented 2 years ago

The detailed information is stored in the info.json file of each scene directory. You can use it as follows:

import json
import os

with open(os.path.join(scene_path, 'info.json'), 'r') as file:
    info_data = json.load(file)
Ks = info_data['blur_parameters']               # 5 blur parameters K
focus_distances = info_data['focus_distances']  # 10 focus distances

for i in range(5):
    for j in range(10):
        disp_focus = 1 / focus_distances[j]     # disparity of the focal plane
        defocus = Ks[i] * (disp - disp_focus)   # signed defocus map

In brief, we use 5 values of K and 10 values of disp_focus. If you normalize disp to [0, 1] (and apply the same normalization to disp_focus), K ranges from 10 to 50 and disp_focus ranges from 0 to 1.
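As a sketch of the loop above, all 5 × 10 defocus maps for one scene can be built at once with NumPy broadcasting. This assumes `disp` is a 2-D NumPy disparity map and that `Ks` and `focus_distances` are the lists read from info.json; the function name `all_defocus_maps` and the toy values are hypothetical, not taken from the dataset.

```python
import numpy as np

def all_defocus_maps(disp, Ks, focus_distances):
    """Return an array of shape (len(Ks), len(focus_distances), H, W).

    Entry [i, j] equals Ks[i] * (disp - 1 / focus_distances[j]), i.e. the
    signed defocus map for blur parameter i and focus distance j.
    """
    disp = np.asarray(disp, dtype=np.float32)                          # (H, W)
    K = np.asarray(Ks, dtype=np.float32)[:, None, None, None]          # (nK, 1, 1, 1)
    disp_focus = 1.0 / np.asarray(focus_distances, dtype=np.float32)   # (nF,)
    disp_focus = disp_focus[None, :, None, None]                       # (1, nF, 1, 1)
    # (1, 1, H, W) - (1, nF, 1, 1) -> (1, nF, H, W); times K -> (nK, nF, H, W)
    return K * (disp[None, None] - disp_focus)

# Toy example with made-up values (not from the BLB dataset).
disp = np.array([[0.5, 1.0], [2.0, 4.0]], dtype=np.float32)
Ks = [1.0, 2.0, 3.0, 4.0, 5.0]
focus_distances = [1.0, 2.0, 4.0, 1.0, 2.0, 4.0, 1.0, 2.0, 4.0, 1.0]
maps = all_defocus_maps(disp, Ks, focus_distances)
print(maps.shape)  # (5, 10, 2, 2)
```

Pixels whose disparity equals 1 / focus_distances[j] get zero defocus, i.e. they lie on the focal plane.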

JuewenPeng commented 2 years ago

Training dataset: https://pan.baidu.com/s/1bTxgBn54kB4xJ4YFoOvcAA?pwd=df72

JialeHu97 commented 2 years ago

Regarding "If you normalize the disp to 0-1 (do the same with disp_focus), K is from 10 to 50, and disp_focus is from 0 to 1": does that mean K = [10, 20, 30, 40, 50] and disp_focus = ((1 / focus_distances) - (1 / focus_distances).min()) / ((1 / focus_distances).max() - (1 / focus_distances).min())?

JuewenPeng commented 2 years ago

It means:

disp_min, disp_max = disp.min(), disp.max()  # range of the raw disparity (computed before disp is overwritten)
disp = (disp - disp_min) / (disp_max - disp_min)  # 0-1
disp_focus = (1 / focus_distances[j] - disp_min) / (disp_max - disp_min)  # 0-1
K = Ks[i] * (disp_max - disp_min)  # 10-50
defocus = K * (disp - disp_focus)

But if you only need the defocus map, you don't need the normalization, since the code above is equivalent to:

disp_focus = 1/focus_distances[j]
K = Ks[i]
defocus = K * (disp - disp_focus)
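The equivalence holds because the (disp_max - disp_min) factors cancel: K_n (d_n - d_f,n) = K (d_max - d_min) · [(d - d_min) - (d_f - d_min)] / (d_max - d_min) = K (d - d_f). A quick numerical check on random data, as a minimal sketch; the function names and the random "disparity" are hypothetical.

```python
import numpy as np

def defocus_normalized(disp, K_raw, focus_distance):
    # Normalized form: rescale disparity and focus disparity to [0, 1]
    # using the range of the raw disparity, and scale K accordingly.
    d_min, d_max = disp.min(), disp.max()
    disp_n = (disp - d_min) / (d_max - d_min)
    disp_focus_n = (1.0 / focus_distance - d_min) / (d_max - d_min)
    K_n = K_raw * (d_max - d_min)
    return K_n * (disp_n - disp_focus_n)

def defocus_raw(disp, K_raw, focus_distance):
    # Unnormalized form: use the raw disparity directly.
    return K_raw * (disp - 1.0 / focus_distance)

rng = np.random.default_rng(0)
disp = rng.uniform(0.1, 2.0, size=(4, 6))  # hypothetical raw disparity
a = defocus_normalized(disp, K_raw=25.0, focus_distance=1.5)
b = defocus_raw(disp, K_raw=25.0, focus_distance=1.5)
print(np.allclose(a, b))  # True
```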
JialeHu97 commented 2 years ago

Thank you very much!!!

JialeHu97 commented 2 years ago

@JuewenPeng Hello, could you provide the code to calculate the error map? I would be very grateful!

JuewenPeng commented 2 years ago

We provide the code below. Note that this implementation differs slightly from the one in the original paper: it is more efficient, and you can use it directly without predicting an error map with ARNet. You can also adjust the two parameters delta1 and delta2 freely.

import torch
import torch.nn as nn
import torch.nn.functional as F
import cupy
import re

def gaussian_blur(x, r, sigma=None):
    r = int(round(r))
    if sigma is None:
        sigma = 0.3 * (r - 1) + 0.8
    x_grid, y_grid = torch.meshgrid(torch.arange(-int(r), int(r) + 1), torch.arange(-int(r), int(r) + 1))
    kernel = torch.exp(-(x_grid ** 2 + y_grid ** 2) / 2 / sigma ** 2)
    kernel = kernel.float() / kernel.sum()
    kernel = kernel.expand(1, 1, 2 * r + 1, 2 * r + 1).to(x.device)
    x = F.pad(x, pad=(r, r, r, r), mode='replicate')
    x = F.conv2d(x, weight=kernel, padding=0)
    return x

kernel_Render_updateOutput = '''

    extern "C" __global__ void kernel_Render_updateOutput(
        const int n,
        const float delta1,
        const float delta2,
        const float threshold1,
        const float threshold2,
        const float* radius_max,     // max blur radius map
        const float* radius_min,     // min blur radius map
        int* error                   // error map
    )
    {
        for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
            const int intN = ( intIndex / SIZE_3(error) / SIZE_2(error) / SIZE_1(error) ) % SIZE_0(error);
            // const int intC = ( intIndex / SIZE_3(error) / SIZE_2(error)                 ) % SIZE_1(error);
            const int intY = ( intIndex / SIZE_3(error)                                 ) % SIZE_2(error);
            const int intX = ( intIndex                                                 ) % SIZE_3(error);

            float fltRadiusMax = VALUE_4(radius_max, intN, 0, intY, intX);
            float fltRadiusMin = VALUE_4(radius_min, intN, 0, intY, intX);

            if ((fltRadiusMax < threshold1) || (fltRadiusMin/fltRadiusMax > threshold2)) {
                continue; 
            }

            for (int intDeltaY = -(int)(fltRadiusMax); intDeltaY <= (int)(fltRadiusMax); ++intDeltaY) {
                for (int intDeltaX = -(int)(fltRadiusMax); intDeltaX <= (int)(fltRadiusMax); ++intDeltaX) {

                    int intNeighborY = intY + intDeltaY;
                    int intNeighborX = intX + intDeltaX;

                    if ((intNeighborY >= 0) && (intNeighborY < SIZE_2(error)) && (intNeighborX >= 0) && (intNeighborX < SIZE_3(error))) {

                        float fltDist = sqrtf((float)(intDeltaY)*(float)(intDeltaY) + (float)(intDeltaX)*(float)(intDeltaX));
                        if (fltDist < fltRadiusMax) {
                            float alpha = fltDist / fltRadiusMax;
                            float beta = fltRadiusMin / fltRadiusMax;
                            // float fltError = (1 - powf(alpha, delta1)) * (1 - powf(beta, delta2));
                            // float fltError = (1 - powf(alpha, delta1)) * (delta2 > beta); // (0.5 + 0.5 * tanhf(50 * (delta2 - beta)));
                            // float fltError = (delta1 > alpha) * (delta2 > beta); // (0.5 + 0.5 * tanhf(50 * (delta2 - beta)));
                            float fltError = (0.5 + 0.5 * tanhf(20 * (delta1 - alpha))) * (0.5 + 0.5 * tanhf(20 * (delta2 - beta)));
                            atomicMax(&error[OFFSET_4(error, intN, 0, intNeighborY, intNeighborX)], int(fltError * 1e8));
                        }                        
                    }
                }
            }
        }
    }

'''

def cupy_kernel(strFunction, objVariables):
    strKernel = globals()[strFunction]

    while True:
        objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArg = int(objMatch.group(2))

        strTensor = objMatch.group(4)
        intSizes = objVariables[strTensor].size()

        strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
    # end

    while True:
        objMatch = re.search(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(
            intStrides[intArg]) + ')' for intArg in range(intArgs)]

        strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')
    # end

    while True:
        objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(
            intStrides[intArg]) + ')' for intArg in range(intArgs)]

        strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
    # end

    return strKernel
# end

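@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    # cupy.RawKernel compiles (and caches) the CUDA source; it replaces the
    # cupy.cuda.compile_with_cache(...).get_function(...) pattern, which is deprecated in recent CuPy releases.
    return cupy.RawKernel(strKernel, strFunction)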
# end

class _FunctionRender(torch.autograd.Function):
    @staticmethod
    def forward(self, radius_max, radius_min, delta1, delta2):
        # self.save_for_backward()

        threshold1 = min(radius_min.shape[2], radius_min.shape[3]) / 1000
        threshold2 = 0.9

        error = torch.zeros_like(radius_max, dtype=torch.int)

        if error.is_cuda == True:
            n = error.nelement()
            cupy_launch('kernel_Render_updateOutput', cupy_kernel('kernel_Render_updateOutput', {
                'delta1': delta1,
                'delta2': delta2,
                'threshold1': threshold1,
                'threshold2': threshold2,
                'radius_max': radius_max,
                'radius_min': radius_min,
                'error': error,
            }))(
                grid=tuple([(n + 512 - 1) // 512, 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cupy.int32(n),  # the kernel declares `const int n`, so pass an explicit 32-bit integer
                    cupy.float32(delta1),
                    cupy.float32(delta2),
                    cupy.float32(threshold1),
                    cupy.float32(threshold2),
                    radius_max.data_ptr(),
                    radius_min.data_ptr(),
                    error.data_ptr(),
                ]
            )

        elif error.is_cuda == False:
            raise NotImplementedError()

        # end

        return error.float() / 1e8
    # end

    # @staticmethod
    # def backward(self, gradBokehCum, gradWeightCum):
    # end

# end

def FunctionRender(radius_max, radius_min, delta1, delta2):
    error = _FunctionRender.apply(radius_max, radius_min, delta1, delta2)

    return error
# end

class ModuleGenError(torch.nn.Module):
    def __init__(self):
        super(ModuleGenError, self).__init__()
    # end

    def forward(self, defocus, delta1, delta2, short_size=384):
        b, _, h, w = defocus.shape
        if short_size:
            h_re = int(round(min(h, max(short_size, short_size * h / w))))
            w_re = int(round(min(w, max(short_size, short_size * w / h))))
            scale = (h * w / h_re / w_re) ** 0.5
            defocus = 1/scale * F.interpolate(defocus, size=(h_re, w_re), mode='bilinear', align_corners=True)
        else:
            h_re = h
            w_re = w

        radius = defocus.abs()
        size = 2
        radius = F.pad(radius, pad=(size, size, size, size), mode='replicate')
        radius = F.unfold(radius, kernel_size=2*size+1)
        radius = radius.reshape(b, -1, h_re, w_re)
        radius_max = radius.max(dim=1, keepdim=True)[0]
        radius_min = radius.min(dim=1, keepdim=True)[0]

        error = FunctionRender(radius_max, radius_min, delta1, delta2)

        error = gaussian_blur(error, 3)

        if short_size:
            error = F.interpolate(error, size=(h, w), mode='bilinear', align_corners=True)

        return error
    # end
# end

if __name__ == '__main__':
    import os
    # os.environ['CUDA_VISIBLE_DEVICES'] = '7'
    defocus = 20 * torch.rand(1, 1, 1080, 1920).cuda()
    module = ModuleGenError().cuda()
    error = module(defocus, delta1=0.9, delta2=0.8, short_size=384)
    print(error)
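Because `_FunctionRender` raises `NotImplementedError` on the CPU, a slow NumPy reference of the per-pixel rule in `kernel_Render_updateOutput` can help when experimenting with delta1 and delta2 on small toy inputs. This is a hypothetical sketch, not the authors' code: it mirrors the thresholds (threshold1 = min(H, W) / 1000, threshold2 = 0.9) and the tanh-based weights, but it omits the 1e8 integer quantization used for `atomicMax`, as well as the 5×5 min/max window, resizing, and Gaussian blur performed in `ModuleGenError`.

```python
import numpy as np

def error_map_reference(radius_max, radius_min, delta1, delta2, threshold2=0.9):
    """CPU reference of the scatter rule in kernel_Render_updateOutput.

    radius_max, radius_min: 2-D float arrays (H, W) for a single image.
    Returns a float array (H, W) with values in [0, 1].
    """
    h, w = radius_max.shape
    threshold1 = min(h, w) / 1000  # same rule as _FunctionRender.forward
    error = np.zeros((h, w), dtype=np.float64)
    for y in range(h):
        for x in range(w):
            r_max = float(radius_max[y, x])
            r_min = float(radius_min[y, x])
            # Skip pixels with tiny blur or nearly uniform blur in their window.
            if r_max < threshold1 or r_min / r_max > threshold2:
                continue
            beta = r_min / r_max
            w_beta = 0.5 + 0.5 * np.tanh(20 * (delta2 - beta))
            r = int(r_max)
            for dy in range(-r, r + 1):
                for dx in range(-r, r + 1):
                    ny, nx = y + dy, x + dx
                    if not (0 <= ny < h and 0 <= nx < w):
                        continue
                    dist = np.hypot(dy, dx)
                    if dist < r_max:
                        alpha = dist / r_max
                        w_alpha = 0.5 + 0.5 * np.tanh(20 * (delta1 - alpha))
                        # Keep the maximum contribution, like atomicMax on the GPU.
                        error[ny, nx] = max(error[ny, nx], w_alpha * w_beta)
    return error

# Toy input: a single pixel whose neighborhood mixes blur radii 0 and 2.
rmax = np.zeros((7, 7))
rmin = np.zeros((7, 7))
rmax[3, 3] = 2.0
err = error_map_reference(rmax, rmin, delta1=0.9, delta2=0.8)
print(np.round(err[3], 3))
```

Lowering delta1 shrinks the disk around each pixel that receives a high error value, and lowering delta2 restricts the error map to pixels with a larger contrast between minimum and maximum blur radius.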
JialeHu97 commented 2 years ago

OK. Thank you very much!!!

JialeHu97 commented 2 years ago

Hello! I ran into some problems while training ARNet and would like to ask for some advice. First, ARNet's loss is decreasing but fluctuates a bit (my loss function does not include the error-map loss term). Second, the bokeh image output by ARNet has colour errors. Have you encountered these problems, and do you have any ideas on how to solve them?

JuewenPeng commented 2 years ago

Maybe the model needs more time to train. Also, our training code is based on DeepFocus, so you can compare your implementation against it to check whether something is wrong.

JialeHu97 commented 2 years ago

Hello! For the training dataset you provided (bokehme_syn_data), do I need to normalize the disparity after reading 'disparity.exr'? That is, is the second step below required?

  1. disp = cv2.imread(item["disp"], -1).astype(np.float32)
  2. disp = (disp - disp.min()) / (disp.max() - disp.min())
JuewenPeng commented 2 years ago

No, you don’t need to do that.

JialeHu97 commented 2 years ago

OK, thank you for your reply!

wzfsjtu commented 2 years ago

Hi, your work is great! I downloaded the training dataset and found that disparity.exr is 512 3, since bokeh_gt and image are 512 512 *3. Could you please give some guidance on how to use it? Thank you very much!

JialeHu97 commented 2 years ago

You can try disp = cv2.imread(disp_path, -1).astype(np.float32).

wzfsjtu commented 2 years ago

Thanks a lot!!
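A likely reason for the unexpected shape is that `cv2.imread` without a flag uses `cv2.IMREAD_COLOR`, which always converts to a 3-channel 8-bit BGR image; the `-1` flag (`cv2.IMREAD_UNCHANGED`) keeps the stored channels and the float depth. Below is a hypothetical loader sketch: the function names are illustrative, and keeping only the first channel assumes that a 3-channel disparity EXR, if present, stores identical values in each channel (mirroring the `[..., 0]` indexing used for depth.exr in the evaluation script). As the author notes above, no normalization is applied.

```python
import numpy as np

def to_single_channel_disparity(disp):
    """Return a float32 (H, W) disparity map without any normalization.

    Assumption: if the array has three channels, they are identical,
    so keeping the first one loses no information.
    """
    disp = np.asarray(disp, dtype=np.float32)
    if disp.ndim == 3:
        disp = disp[..., 0]
    return disp

def load_disparity(path):
    # Assumption: recent opencv-python builds only read EXR files when
    # OPENCV_IO_ENABLE_OPENEXR=1 is set, ideally before cv2 is first imported.
    import os
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
    import cv2
    disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # same as the -1 flag
    if disp is None:
        raise FileNotFoundError(path)
    return to_single_channel_disparity(disp)
```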

JialeHu97 commented 2 years ago

@wzfsjtu I trained the model using my own implementation of the training code, but it didn't work very well. I wonder if you have trained the model and how well it works.

JialeHu97 commented 2 years ago

@JuewenPeng Hi! I have a question. If I want to use a DSLR to create the dataset, how should I determine the parameters K and gamma for each bokeh image taken by the camera?

JuewenPeng commented 2 years ago

I think it is really hard. Besides, it is hard to obtain the disparity map, and captured all-in-focus and bokeh image pairs suffer from color inconsistency and misalignment, as in the EBB! dataset.

JialeHu97 commented 2 years ago

OK, I got it. Thank you!


JialeHu97 commented 2 years ago

@JuewenPeng Hi, brother! Have you applied any pre-processing or post-processing to further improve the results? Could you provide the relevant code? The results I have obtained, both from my own replication and from the model you provided, still fall a bit short of the metrics in your paper, so I would like to ask which pre-processing and post-processing operations could improve them.

JuewenPeng commented 2 years ago

Do you mean the evaluation on the BLB dataset?

JuewenPeng commented 2 years ago

The pretrained model we provide in this repository is different from the one in the original paper, but in our experiments it performs much better. You can use the following code to evaluate the model on the BLB dataset. Remember to change the data path to your own.

# NOTE
# In the BLB dataset, the maximum values of the all-in-focus image and the bokeh ground truth may be larger than 1.
# The predicted bokeh image and the ground truth are clipped to [0, 1] before evaluation.

import os

import cv2
import numpy as np

import time
import xlwt
import json
import warnings
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

############# import your model #############
from neural_renderer import ARNet, IUNet
from classical_renderer.scatter import ModuleRenderScatter  # circular aperture
#############################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def calc_psnr(pred, gt, data_range=1):
    if data_range == 1:
        pred = pred * 255
        gt = gt * 255
    mse = torch.mean((pred - gt) ** 2)
    if mse == 0:
        return float('inf')
    else:
        return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()

def calc_ssim(X, Y, mask=None, data_range=1, size_average=True, win_size=11, win_sigma=1.5, win=None, K=(0.01, 0.03), nonnegative_ssim=False):
    r""" interface of ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
        nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu

    Returns:
        torch.Tensor: ssim results
    """
    if not X.shape == Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if len(X.shape) not in (4, 5):
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if not X.type() == Y.type():
        raise ValueError("Input images should have the same dtype.")

    if win is not None:  # set win_size
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    ssim_per_channel, cs = _ssim(X, Y, mask=mask, data_range=data_range, win=win, size_average=False, K=K)
    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    if size_average:
        return ssim_per_channel.mean()
    else:
        return ssim_per_channel.mean(1)

def _fspecial_gauss_1d(size, sigma):
    r"""Create 1-D gauss kernel
    Args:
        size (int): the size of gauss kernel
        sigma (float): sigma of normal distribution

    Returns:
        torch.Tensor: 1D kernel (1 x 1 x size)
    """
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size // 2

    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()

    return g.unsqueeze(0).unsqueeze(0)

def gaussian_filter(input, win):
    r""" Blur input with 1-D kernel
    Args:
        input (torch.Tensor): a batch of tensors to be blurred
        win (torch.Tensor): 1-D gauss kernel

    Returns:
        torch.Tensor: blurred tensors
    """
    assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
    if len(input.shape) == 4:
        conv = F.conv2d
    elif len(input.shape) == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    C = input.shape[1]
    out = input
    for i, s in enumerate(input.shape[2:]):
        if s >= win.shape[-1]:
            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
        else:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
            )

    return out

def _ssim(X, Y, mask, data_range, win, size_average=True, K=(0.01, 0.03)):

    r""" Calculate ssim index for X and Y

    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        win (torch.Tensor): 1-D gauss kernel
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar

    Returns:
        torch.Tensor: ssim results.
    """
    K1, K2 = K
    # batch, channel, [depth,] height, width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    if mask is None:
        ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
        cs = torch.flatten(cs_map, 2).mean(-1)
    else:
        crop_size = int(win.shape[3] // 2)
        mask = mask[:, :, crop_size:-crop_size, crop_size:-crop_size]
        ssim_per_channel = torch.flatten(ssim_map * mask, 2).mean(-1) / torch.flatten(mask, 2).mean(-1)
        cs = torch.flatten(cs_map * mask, 2).mean(-1) / torch.flatten(mask, 2).mean(-1)
    return ssim_per_channel, cs

def style(bold=False, underline=False, italic=False, auto_warp=True, align_h='center', align_v='center'):
    style = xlwt.XFStyle()

    font = xlwt.Font()
    font.bold = bold
    font.underline = underline
    font.italic = italic
    style.font = font

    alignment = xlwt.Alignment()
    if align_h == 'center':
        alignment.horz = xlwt.Alignment.HORZ_CENTER
    elif align_h == 'left':
        alignment.horz = xlwt.Alignment.HORZ_LEFT
    elif align_h == 'right':
        alignment.horz = xlwt.Alignment.HORZ_RIGHT
    else:
        raise ValueError(f'Unsupported horizontal alignment: {align_h}')
    if align_v == 'center':
        alignment.vert = xlwt.Alignment.VERT_CENTER
    elif align_v == 'top':
        alignment.vert = xlwt.Alignment.VERT_TOP
    elif align_v == 'bottom':
        alignment.vert = xlwt.Alignment.VERT_BOTTOM
    else:
        raise ValueError(f'Unsupported vertical alignment: {align_v}')
    alignment.wrap = int(auto_warp)
    style.alignment = alignment

    return style

########################## change to your rendering pipeline ##########################
def gaussian_blur(x, r, sigma=None):
    r = int(round(r))
    if sigma is None:
        sigma = 0.3 * (r - 1) + 0.8
    x_grid, y_grid = torch.meshgrid(torch.arange(-int(r), int(r) + 1), torch.arange(-int(r), int(r) + 1))
    kernel = torch.exp(-(x_grid ** 2 + y_grid ** 2) / 2 / sigma ** 2)
    kernel = kernel.float() / kernel.sum()
    kernel = kernel.expand(1, 1, 2*r+1, 2*r+1).to(x.device)
    x = F.pad(x, pad=(r, r, r, r), mode='replicate')
    x = F.conv2d(x, weight=kernel, padding=0)
    return x

def pipeline(classical_renderer, arnet, iunet, image, defocus, gamma, args):
    bokeh_classical, defocus_dilate = classical_renderer(image**gamma, defocus*args.defocus_scale)

    bokeh_classical = bokeh_classical ** (1/gamma)
    defocus_dilate = defocus_dilate / args.defocus_scale
    gamma = (gamma - args.gamma_min) / (args.gamma_max - args.gamma_min)
    adapt_scale = max(defocus.abs().max().item(), 1)

    image_re = F.interpolate(image, scale_factor=1/adapt_scale, mode='bilinear', align_corners=True)
    defocus_re = 1 / adapt_scale * F.interpolate(defocus, scale_factor=1/adapt_scale, mode='bilinear', align_corners=True)
    bokeh_neural, error_map = arnet(image_re, defocus_re, gamma)
    error_map = F.interpolate(error_map, size=(image.shape[2], image.shape[3]), mode='bilinear', align_corners=True)
    bokeh_neural.clamp_(0, 1e5)

    for scale in range(int(np.log2(adapt_scale))):
        ratio = 2**(scale+1) / adapt_scale
        h_re, w_re = int(ratio * image.shape[2]), int(ratio * image.shape[3])
        image_re = F.interpolate(image, size=(h_re, w_re), mode='bilinear', align_corners=True)
        defocus_re = ratio * F.interpolate(defocus, size=(h_re, w_re), mode='bilinear', align_corners=True)
        defocus_dilate_re = ratio * F.interpolate(defocus_dilate, size=(h_re, w_re), mode='bilinear', align_corners=True)
        bokeh_neural_refine = iunet(image_re, defocus_re.clamp(-1, 1), bokeh_neural, gamma).clamp(0, 1e5)
        mask = gaussian_blur(((defocus_dilate_re < 1) * (defocus_dilate_re > -1)).float(), 0.005 * (defocus_dilate_re.shape[2] + defocus_dilate_re.shape[3]))
        bokeh_neural = mask * bokeh_neural_refine + (1 - mask) * F.interpolate(bokeh_neural, size=(h_re, w_re), mode='bilinear', align_corners=True)

    bokeh_neural_refine = iunet(image, defocus.clamp(-1, 1), bokeh_neural, gamma).clamp(0, 1e5)
    mask = gaussian_blur(((defocus_dilate < 1) * (defocus_dilate > -1)).float(), 0.005 * (defocus_dilate.shape[2] + defocus_dilate.shape[3]))
    bokeh_neural = mask * bokeh_neural_refine + (1 - mask) * F.interpolate(bokeh_neural, size=(image.shape[2], image.shape[3]), mode='bilinear', align_corners=True)

    bokeh_pred = bokeh_classical * (1 - error_map) + bokeh_neural * error_map

    return bokeh_pred.clamp(0, 1), bokeh_classical.clamp(0, 1), bokeh_neural.clamp(0, 1), error_map
#######################################################################################

def main():
    gamma = 2.2

    ############################# change to your settings #############################
    method = 'BokehMe'            # method name
    root = '/data2/pengjuewen/Bokeh/Blender/data'        # path to BLB dataset
    save_root = os.path.join('./BLB', method)  # path to save

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(description='Bokeh Rendering', fromfile_prefix_chars='@')

    parser.add_argument('--defocus_scale', type=float, default=10.)
    parser.add_argument('--gamma_min', type=float, default=1.)
    parser.add_argument('--gamma_max', type=float, default=5.)

    # Model 1
    parser.add_argument('--arnet_shuffle_rate', type=int, default=2)
    parser.add_argument('--arnet_in_channels', type=int, default=5)
    parser.add_argument('--arnet_out_channels', type=int, default=4)
    parser.add_argument('--arnet_middle_channels', type=int, default=128)
    parser.add_argument('--arnet_num_block', type=int, default=3)
    parser.add_argument('--arnet_share_weight', action='store_true')
    parser.add_argument('--arnet_connect_mode', type=str, default='distinct_source')
    parser.add_argument('--arnet_use_bn', action='store_true')
    parser.add_argument('--arnet_activation', type=str, default='elu')

    # Model 2
    parser.add_argument('--iunet_shuffle_rate', type=int, default=2)
    parser.add_argument('--iunet_in_channels', type=int, default=8)
    parser.add_argument('--iunet_out_channels', type=int, default=3)
    parser.add_argument('--iunet_middle_channels', type=int, default=64)
    parser.add_argument('--iunet_num_block', type=int, default=3)
    parser.add_argument('--iunet_share_weight', action='store_true')
    parser.add_argument('--iunet_connect_mode', type=str, default='distinct_source')
    parser.add_argument('--iunet_use_bn', action='store_true')
    parser.add_argument('--iunet_activation', type=str, default='elu')

    # Checkpoint
    parser.add_argument('--arnet_checkpoint_path', type=str, default='./checkpoints/arnet.pth')
    parser.add_argument('--iunet_checkpoint_path', type=str, default='./checkpoints/iunet.pth')

    # Input
    args = parser.parse_args()

    arnet_checkpoint_path = args.arnet_checkpoint_path
    iunet_checkpoint_path = args.iunet_checkpoint_path

    classical_renderer = ModuleRenderScatter().to(device)

    arnet = ARNet(args.arnet_shuffle_rate, args.arnet_in_channels, args.arnet_out_channels, args.arnet_middle_channels,
                  args.arnet_num_block, args.arnet_share_weight, args.arnet_connect_mode, args.arnet_use_bn,
                  args.arnet_activation)
    iunet = IUNet(args.iunet_shuffle_rate, args.iunet_in_channels, args.iunet_out_channels, args.iunet_middle_channels,
                  args.iunet_num_block, args.iunet_share_weight, args.iunet_connect_mode, args.iunet_use_bn,
                  args.iunet_activation)

    arnet.to(device)
    iunet.to(device)

    checkpoint = torch.load(arnet_checkpoint_path, map_location=device)
    arnet.load_state_dict(checkpoint['model'])
    checkpoint = torch.load(iunet_checkpoint_path, map_location=device)
    iunet.load_state_dict(checkpoint['model'])

    arnet.eval()
    iunet.eval()
    ###################################################################################

    os.makedirs(save_root, exist_ok=True)

    scene_lst = [name for name in sorted(os.listdir(root)) if '.' not in name]
    disp_focus_lst = [f'({i+1})' for i in range(10)]
    metric_lst = ['psnr', 'ssim', 'runtime']

    scene_num = len(scene_lst)
    disp_focus_num = len(disp_focus_lst)
    metric_num = len(metric_lst)

    with torch.no_grad():
        for K_idx in range(5):
            scene_metric_avg = np.zeros([metric_num, scene_num])
            disp_focus_metric_avg = np.zeros([metric_num, disp_focus_num])

            # initialize excel
            workbook = xlwt.Workbook(encoding='utf-8')
            worksheets = [workbook.add_sheet(ind, cell_overwrite_ok=True) for ind in metric_lst]

            standard_style = style()
            left_style = style(align_h='left')

            for worksheet in worksheets:
                worksheet.write_merge(0, 1, 0, 1, method, style=standard_style)
                worksheet.write_merge(0, 0, 2, 1 + disp_focus_num, 'Refocused Disparity', style=standard_style)
                for i, disp_focus in enumerate(disp_focus_lst):
                    worksheet.write(1, 2 + i, disp_focus, style=standard_style)
                worksheet.write_merge(2, 1 + scene_num, 0, 0, 'Scene', style=standard_style)
                worksheet.write_merge(2 + scene_num, 2 + scene_num, 0, 1, 'Average', style=standard_style)
                worksheet.write_merge(0, 1, 2 + disp_focus_num, 2 + disp_focus_num, 'Average', style=standard_style)

            for scene_idx in range(scene_num):
                scene_name = scene_lst[scene_idx]
                scene_path = os.path.join(root, scene_name)
                save_scene_path = os.path.join(save_root, scene_name)

                os.makedirs(save_scene_path, exist_ok=True)

                image = cv2.imread(os.path.join(scene_path, 'image.exr'), -1)[..., :3].astype(np.float32) ** (1/gamma)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Input RGB and output RGB by default

                ###### change to the name of corrupted depth map if necessary ######
                depth_name = 'depth.exr'
                ####################################################################

                depth = cv2.imread(os.path.join(scene_path, depth_name), -1)[..., 0].astype(np.float32)
                disp = 1 / depth

                ############# comment it if using tensorflow model #############
                image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
                disp = torch.from_numpy(disp).unsqueeze(0).unsqueeze(0).contiguous().to(device)
                ################################################################

                for metric_idx in range(metric_num):
                    worksheets[metric_idx].write(2+scene_idx, 1, scene_name, style=standard_style)

                with open(os.path.join(scene_path, 'info.json'), 'r') as file:
                    info_data = json.load(file)
                Ks = info_data['blur_parameters']
                focus_distances = info_data['focus_distances']

                for df_idx in range(len(disp_focus_lst)):
                    K = Ks[K_idx]
                    disp_focus = 1 / focus_distances[df_idx]
                    defocus = K * (disp - disp_focus) / args.defocus_scale

                    gt_name = f'bokeh_{K_idx:0>2d}_{df_idx:0>2d}.exr'
                    gt = cv2.imread(os.path.join(scene_path, gt_name), -1)[..., :3].astype(np.float32) ** (1/gamma)
                    gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
                    gt = torch.from_numpy(gt).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

                    # neglect the runtime of the first inference
                    if scene_idx + df_idx == 0:
                        pipeline(classical_renderer, arnet, iunet, image, defocus, gamma, args)

                    ###### comment it if using pytorch cpu or tensorflow model ######
                    torch.cuda.synchronize()
                    #################################################################
                    start = time.time()

                    pred = pipeline(classical_renderer, arnet, iunet, image, defocus, gamma, args)[0]

                    ###### comment it if using pytorch cpu or tensorflow model ######
                    torch.cuda.synchronize()
                    #################################################################
                    end = time.time()

                    ############ uncomment it if using tensorflow model ############
                    # pred = torch.from_numpy(pred).to(device)
                    ################################################################

                    # evaluation
                    psnr = calc_psnr(pred.clamp(0, 1), gt.clamp(0, 1))
                    ssim = calc_ssim(pred.clamp(0, 1), gt.clamp(0, 1))

                    # save results
                    pred = pred[0].cpu().clone().permute(1, 2, 0).numpy()
                    pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)

                    save_name = f'bokeh_{K_idx:0>2d}_{df_idx:0>2d}.jpg'
                    cv2.imwrite(os.path.join(save_scene_path, save_name), pred * 255)

                    runtime = end - start

                    scene_metric_avg[0, scene_idx] += psnr
                    scene_metric_avg[1, scene_idx] += ssim
                    scene_metric_avg[2, scene_idx] += runtime

                    disp_focus_metric_avg[0, df_idx] += psnr
                    disp_focus_metric_avg[1, df_idx] += ssim
                    disp_focus_metric_avg[2, df_idx] += runtime

                    # write to excel
                    ii = scene_idx + 2
                    jj = df_idx + 2
                    worksheets[0].write(ii, jj, float(psnr), style=left_style)
                    worksheets[1].write(ii, jj, float(ssim), style=left_style)
                    worksheets[2].write(ii, jj, float(runtime), style=left_style)

                    print(f'scene[{scene_idx+1}/{scene_num}]  disp_focus[{df_idx+1}/{disp_focus_num}]  '
                          f'PSNR:{psnr}  SSIM:{ssim}  Runtime: {runtime}')

            scene_metric_avg /= disp_focus_num
            disp_focus_metric_avg /= scene_num
            assert np.abs(scene_metric_avg.mean(axis=1)[0] - disp_focus_metric_avg.mean(axis=1)[0]) < 1e-5
            metric_avg = scene_metric_avg.mean(axis=1)

            for scene_idx in range(scene_num):
                for df_idx in range(disp_focus_num):
                    ii = scene_idx + 2
                    jj = df_idx + 2
                    for metric_idx in range(metric_num):
                        worksheets[metric_idx].write(ii, 2+disp_focus_num, float(scene_metric_avg[metric_idx, scene_idx]), style=left_style)
                        worksheets[metric_idx].write(2+scene_num, jj, float(disp_focus_metric_avg[metric_idx, df_idx]), style=left_style)

            for metric_idx in range(metric_num):
                worksheets[metric_idx].write(2+scene_num, 2+disp_focus_num, float(metric_avg[metric_idx]), style=left_style)

            xls_name = f'evaluation_K={int(10*(K_idx+1))}.xls'
            workbook.save(os.path.join(save_root, xls_name))

    print(f'"{method}" evaluation done!')

if __name__ == '__main__':
    main()
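
For each K, the inner loop above builds a signed defocus map for every (scene, disp_focus) pair and then averages each metric over a scene × disp_focus grid. Here is a minimal NumPy sketch of those two steps. It is not part of the released script: the function names and the default defocus_scale=1.0 are assumptions (in the script the scale comes from args.defocus_scale), and it handles one metric at a time instead of the (metric_num, ...) arrays used above.

```python
import numpy as np


def defocus_map(disp, K, focus_distance, defocus_scale=1.0):
    """Signed defocus map, mirroring K * (disp - disp_focus) / args.defocus_scale.

    The name and the defocus_scale default are illustrative assumptions.
    """
    disp_focus = 1.0 / focus_distance  # focus distance -> focus disparity
    return K * (disp - disp_focus) / defocus_scale


def grid_averages(metric):
    """metric: array of shape (scene_num, disp_focus_num) for one metric and one K."""
    scene_avg = metric.mean(axis=1)  # one value per scene (row), like scene_metric_avg
    focus_avg = metric.mean(axis=0)  # one value per focus setting (column), like disp_focus_metric_avg
    overall = scene_avg.mean()
    # On a complete grid, both orders of averaging give the same grand mean,
    # which is what the assert in the script checks for PSNR.
    assert np.isclose(overall, focus_avg.mean())
    return scene_avg, focus_avg, overall
```

For example, with disp = [[0.5, 1.0], [2.0, 4.0]], K = 10 and a focus distance of 1, the defocus map is [[-5, 0], [10, 30]]: it is zero exactly on the focal plane and its sign tells which side of the focal plane a pixel lies on.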
JialeHu97 commented 2 years ago

OK, thank you very much!

JialeHu97 commented 2 years ago

@JuewenPeng Hi! I have two questions about model training. First, I see that you did not normalize 'image.exr' and 'bokeh.exr' after reading them when evaluating on BLB. Do I need to normalize 'image.exr' and 'bokeh.exr' after reading them when training the model? Second, are the predicted bokeh image and the ground truth clipped to [0, 1] before calculating the loss? Looking forward to your answers!

JialeHu97 commented 2 years ago

@JuewenPeng Hi! Could you provide the EBB400 test set together with the corresponding disparity maps and refocused disparities? I would be very grateful!

JuewenPeng commented 2 years ago

For the first question, you don't need to normalize the images during training. Values of all RGB images are supposed to be in the range [0, 1.5]. For the second question, in the evaluation we clip the bokeh images to [0, 1] for all methods, since some of them cannot output values outside [0, 1] (our methods can), so we think this clipping makes the comparisons fairer.
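
To show why the [0, 1] clipping matters when images can hold values up to 1.5, here is a minimal NumPy sketch of PSNR with optional clipping. It is not the repository's calc_psnr (whose implementation is not shown in this thread); the function name, peak=1.0 and the clip flag are assumptions for illustration.

```python
import numpy as np


def psnr(pred, gt, peak=1.0, clip=True):
    """PSNR in dB; optionally clip both images to [0, 1] first (hypothetical helper)."""
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if clip:
        # Saturate everything above 1.0, as in the evaluation-time comparison.
        pred = np.clip(pred, 0.0, 1.0)
        gt = np.clip(gt, 0.0, 1.0)
    mse = np.mean((pred - gt) ** 2)
    if mse == 0:
        return float('inf')
    return float(10.0 * np.log10(peak ** 2 / mse))
```

With gt = [1.2, 0.5] and pred = [1.0, 0.5], a method that saturates the highlight at 1.0 matches the ground truth exactly after clipping (infinite PSNR), while without clipping it is penalized for the missing 0.2 above 1.0. Clipping therefore puts methods that cannot exceed 1.0 on equal footing with ones that can.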

JuewenPeng commented 2 years ago

Do you really need that? Honestly, I don't think this experiment is entirely fair or necessary, since there is color inconsistency and misalignment between the input image and the bokeh image in each pair. We did it only to make the experiments more complete.

JialeHu97 commented 2 years ago

Here's the thing: I want to see the bokeh effect on real scenes, and EBB400 is just right for this. However, generating disparity maps and determining refocused disparity values is time-consuming, and it would save me a lot of work if you could provide ready-made data.

JuewenPeng commented 2 years ago

OK, let me upload it to the Baidu Netdisk.

JialeHu97 commented 2 years ago

Thank you very much for your answers. Regarding the second question, I would like to know whether it is necessary to clip the predicted bokeh and the ground-truth bokeh to [0, 1] before calculating the loss when training the model.

JialeHu97 commented 2 years ago

Thank you very much!!!

JuewenPeng commented 2 years ago

I think it's optional, but in my practice, I didn't clip the predicted values during the training.
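
As a rough illustration of what clipping changes during training, here is a NumPy sketch with an L1 loss and a finite-difference gradient. The L1 loss, the function names and the clip flag are assumptions; this is not the authors' training code. Since ground-truth values can reach 1.5, clipping inside the loss hides errors above 1.0 and gives zero gradient there. That is one possible argument for training on unclipped values; the reply above only says clipping is optional.

```python
import numpy as np


def l1_loss(pred, gt, clip=False):
    """Mean absolute error; optionally clip both images to [0, 1] first (hypothetical)."""
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if clip:
        pred = np.clip(pred, 0.0, 1.0)
        gt = np.clip(gt, 0.0, 1.0)
    return float(np.mean(np.abs(pred - gt)))


def numeric_grad(f, x, eps=1e-4):
    """Central finite difference of a scalar function f at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)
```

For a highlight pixel with gt = 1.4 and pred = 1.1, the unclipped loss is 0.3 with gradient -1 (pushing pred upward), while the clipped loss is 0 with gradient 0, so the network gets no signal to recover the highlight.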

JialeHu97 commented 2 years ago

OK, thank you very much!

JialeHu97 commented 2 years ago

I'd like to know whether the pretrained model provided in this repository was trained only on the training set in bokehme_syn_data, or whether you also added some data from BLB. I trained the model using only the training set in bokehme_syn_data, and when testing on BLB, SSIM only reaches 0.97. Is there anything I should pay attention to during training?

JuewenPeng commented 2 years ago

We only trained our model on the synthetic dataset.

JuewenPeng commented 2 years ago

EBB400 Baidu Netdisk: https://pan.baidu.com/s/1l3Rug16HEB2uUi3u366vLw?pwd=f7mp

We conducted our experiment using the disparity maps in the disparity directory, all of which are predicted by MiDaS. We also provide disparity maps predicted by DPT in the disparity_dpt directory. You can use them if you expect better bokeh rendering results.
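
A minimal sketch of switching between the two disparity sources and bringing a relative (MiDaS/DPT-style) disparity map into [0, 1], so that K and disp_focus can be chosen in normalized units as described earlier in this thread (K from 10 to 50, disp_focus from 0 to 1). Only the directory names come from the reply above; the helper names and the min-max normalization are assumptions, and the file names inside each directory are not specified here.

```python
import os

import numpy as np


def disparity_dir(root, use_dpt=False):
    """Pick the MiDaS ('disparity') or DPT ('disparity_dpt') folder (hypothetical helper)."""
    return os.path.join(root, 'disparity_dpt' if use_dpt else 'disparity')


def normalize_disparity(disp):
    """Min-max normalize a relative disparity map to [0, 1] (assumed preprocessing)."""
    disp = np.asarray(disp, dtype=np.float32)
    lo, hi = float(disp.min()), float(disp.max())
    if hi - lo < 1e-8:
        # Constant map: no depth variation, so return all zeros.
        return np.zeros_like(disp)
    return (disp - lo) / (hi - lo)
```

After normalization, a defocus map would follow the same form as in the evaluation script, e.g. K * (normalize_disparity(disp) - disp_focus) with K in [10, 50] and disp_focus in [0, 1].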

JialeHu97 commented 2 years ago

OK, thank you very much!

quhaoooo commented 1 year ago

Thanks for EBB400. I have another question: I can't download the entire EBB training set because the website (https://competitions.codalab.org/competitions/24716#participate) no longer works. Do you have a backup you could share?

JuewenPeng commented 1 year ago

Sorry about that, but you'd better register for the competition first and then download the entire dataset.

quhaoooo commented 1 year ago

But it seems that I can't register for the competition because it has ended.

JuewenPeng commented 1 year ago

I remember that one can register for the competition at any time.

quhaoooo commented 1 year ago

This? [screenshot]