I'm sorry, but we are not planning to provide the training code at the moment.
That's okay, thank you.
Thanks for your understanding.
@JuewenPeng Hello, fellow HUST alumnus, could you provide the training dataset? I would be very grateful!
I am uploading the data to Baidu Netdisk; it will take some time. BTW, you can refer to our other paper (MPIB: An MPI-Based Bokeh Rendering Framework for Realistic Partial Occlusion Effects) for the details of data generation.
Ok, thank you!
@JuewenPeng Hello, could you provide the K and disp_focus values used in the BLB dataset? I would be very grateful!
The detailed information is given in the info.json file of each scene directory. You can use it as follows.

import os
import json

# scene_path: path to one scene directory of the BLB dataset
with open(os.path.join(scene_path, 'info.json'), 'r') as file:
    info_data = json.load(file)
Ks = info_data['blur_parameters']               # 5 blur parameters K
focus_distances = info_data['focus_distances']  # 10 focus distances
for i in range(5):
    for j in range(10):
        disp_focus = 1 / focus_distances[j]    # refocused disparity
        defocus = Ks[i] * (disp - disp_focus)  # disp: disparity map of the scene
Briefly, we set 5 values of K and 10 values of disp_focus. If you normalize disp to 0-1 (and do the same with disp_focus), K ranges from 10 to 50 and disp_focus from 0 to 1.
Training dataset: https://pan.baidu.com/s/1bTxgBn54kB4xJ4YFoOvcAA?pwd=df72
Does that mean K = [10, 20, 30, 40, 50] and disp_focus = ((1 / focus_distances) - (1 / focus_distances).min()) / ((1 / focus_distances).max() - (1 / focus_distances).min())?
That means

disp_min, disp_max = disp.min(), disp.max()
disp = (disp - disp_min) / (disp_max - disp_min)  # normalize disp to 0-1
disp_focus = (1 / focus_distances[j] - disp_min) / (disp_max - disp_min)  # 0-1
K = Ks[i] * (disp_max - disp_min)  # 10-50
defocus = K * (disp - disp_focus)
But if you just need the defocus map, you don't need to do the normalization, since the above code is equivalent to

disp_focus = 1 / focus_distances[j]
K = Ks[i]
defocus = K * (disp - disp_focus)
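As a quick sanity check of this equivalence, here is a minimal sketch with arbitrary toy values (the variable names and numbers are illustrative only, not taken from the dataset):

import torch

disp = torch.rand(1, 1, 64, 64) * 5 + 2  # toy disparity map
focus_distance, K_raw = 0.5, 4.0         # arbitrary example values

# unnormalized form
defocus_a = K_raw * (disp - 1 / focus_distance)

# normalized form
d_min, d_max = disp.min(), disp.max()
disp_n = (disp - d_min) / (d_max - d_min)
disp_focus_n = (1 / focus_distance - d_min) / (d_max - d_min)
defocus_b = K_raw * (d_max - d_min) * (disp_n - disp_focus_n)

print(torch.allclose(defocus_a, defocus_b, atol=1e-5))  # True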
Thank you very much!!!
@JuewenPeng Hello, could you provide the code to calculate the error map? I would be very grateful!
We provide the code below. Note that this implementation is a little different from what we did in the original paper. This one is more efficient, and you can use it directly without predicting an error map with ARNet. Plus, you can adjust the two parameters delta1 and delta2 freely.
import torch
import torch.nn as nn
import torch.nn.functional as F
import cupy
import re


def gaussian_blur(x, r, sigma=None):
    r = int(round(r))
    if sigma is None:
        sigma = 0.3 * (r - 1) + 0.8  # default sigma from the kernel radius (same heuristic as OpenCV)
    x_grid, y_grid = torch.meshgrid(torch.arange(-int(r), int(r) + 1), torch.arange(-int(r), int(r) + 1))
    kernel = torch.exp(-(x_grid ** 2 + y_grid ** 2) / 2 / sigma ** 2)
    kernel = kernel.float() / kernel.sum()
    kernel = kernel.expand(1, 1, 2 * r + 1, 2 * r + 1).to(x.device)
    x = F.pad(x, pad=(r, r, r, r), mode='replicate')
    x = F.conv2d(x, weight=kernel, padding=0)
    return x
kernel_Render_updateOutput = '''
    extern "C" __global__ void kernel_Render_updateOutput(
        const int n,
        const float delta1,
        const float delta2,
        const float threshold1,
        const float threshold2,
        const float* radius_max,  // max blur radius map
        const float* radius_min,  // min blur radius map
        int* error                // error map
    )
    {
        for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
            const int intN = ( intIndex / SIZE_3(error) / SIZE_2(error) / SIZE_1(error) ) % SIZE_0(error);
            // const int intC = ( intIndex / SIZE_3(error) / SIZE_2(error) ) % SIZE_1(error);
            const int intY = ( intIndex / SIZE_3(error) ) % SIZE_2(error);
            const int intX = ( intIndex ) % SIZE_3(error);

            float fltRadiusMax = VALUE_4(radius_max, intN, 0, intY, intX);
            float fltRadiusMin = VALUE_4(radius_min, intN, 0, intY, intX);

            if ((fltRadiusMax < threshold1) || (fltRadiusMin/fltRadiusMax > threshold2)) {
                continue;
            }

            for (int intDeltaY = -(int)(fltRadiusMax); intDeltaY <= (int)(fltRadiusMax); ++intDeltaY) {
                for (int intDeltaX = -(int)(fltRadiusMax); intDeltaX <= (int)(fltRadiusMax); ++intDeltaX) {
                    int intNeighborY = intY + intDeltaY;
                    int intNeighborX = intX + intDeltaX;
                    if ((intNeighborY >= 0) && (intNeighborY < SIZE_2(error)) && (intNeighborX >= 0) && (intNeighborX < SIZE_3(error))) {
                        float fltDist = sqrtf((float)(intDeltaY)*(float)(intDeltaY) + (float)(intDeltaX)*(float)(intDeltaX));
                        if (fltDist < fltRadiusMax) {
                            float alpha = fltDist / fltRadiusMax;
                            float beta = fltRadiusMin / fltRadiusMax;
                            // float fltError = (1 - powf(alpha, delta1)) * (1 - powf(beta, delta2));
                            // float fltError = (1 - powf(alpha, delta1)) * (delta2 > beta); // (0.5 + 0.5 * tanhf(50 * (delta2 - beta)));
                            // float fltError = (delta1 > alpha) * (delta2 > beta); // (0.5 + 0.5 * tanhf(50 * (delta2 - beta)));
                            float fltError = (0.5 + 0.5 * tanhf(20 * (delta1 - alpha))) * (0.5 + 0.5 * tanhf(20 * (delta2 - beta)));
                            atomicMax(&error[OFFSET_4(error, intN, 0, intNeighborY, intNeighborX)], int(fltError * 1e8));
                        }
                    }
                }
            }
        }
    }
'''
def cupy_kernel(strFunction, objVariables):
    strKernel = globals()[strFunction]

    # replace SIZE_i(tensor) macros with the actual tensor sizes
    while True:
        objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
        if objMatch is None:
            break
        intArg = int(objMatch.group(2))
        strTensor = objMatch.group(4)
        intSizes = objVariables[strTensor].size()
        strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))

    # replace OFFSET_i(tensor, idx...) macros with flat memory offsets
    while True:
        objMatch = re.search(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
        if objMatch is None:
            break
        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')
        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs)]
        strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')

    # replace VALUE_i(tensor, idx...) macros with indexed reads
    while True:
        objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
        if objMatch is None:
            break
        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')
        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs)]
        strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')

    return strKernel


# @cupy.util.memoize(for_each_device=True)  # for older CuPy versions
@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
class _FunctionRender(torch.autograd.Function):
    @staticmethod
    def forward(self, radius_max, radius_min, delta1, delta2):
        # self.save_for_backward()
        threshold1 = min(radius_min.shape[2], radius_min.shape[3]) / 1000
        threshold2 = 0.9
        error = torch.zeros_like(radius_max, dtype=torch.int)

        if error.is_cuda:
            n = error.nelement()
            cupy_launch('kernel_Render_updateOutput', cupy_kernel('kernel_Render_updateOutput', {
                'delta1': delta1,
                'delta2': delta2,
                'threshold1': threshold1,
                'threshold2': threshold2,
                'radius_max': radius_max,
                'radius_min': radius_min,
                'error': error,
            }))(
                grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cupy.int32(n),  # cupy.int was removed in recent CuPy versions
                    cupy.float32(delta1),
                    cupy.float32(delta2),
                    cupy.float32(threshold1),
                    cupy.float32(threshold2),
                    radius_max.data_ptr(),
                    radius_min.data_ptr(),
                    error.data_ptr(),
                ]
            )
        else:
            raise NotImplementedError()

        # the kernel accumulates int(error * 1e8); convert back to float
        return error.float() / 1e8

    # @staticmethod
    # def backward(self, gradBokehCum, gradWeightCum):


def FunctionRender(radius_max, radius_min, delta1, delta2):
    error = _FunctionRender.apply(radius_max, radius_min, delta1, delta2)
    return error
class ModuleGenError(torch.nn.Module):
    def __init__(self):
        super(ModuleGenError, self).__init__()

    def forward(self, defocus, delta1, delta2, short_size=384):
        b, _, h, w = defocus.shape

        # optionally downsample so that the short side is about `short_size`
        if short_size:
            h_re = int(round(min(h, max(short_size, short_size * h / w))))
            w_re = int(round(min(w, max(short_size, short_size * w / h))))
            scale = (h * w / h_re / w_re) ** 0.5
            defocus = 1/scale * F.interpolate(defocus, size=(h_re, w_re), mode='bilinear', align_corners=True)
        else:
            h_re = h
            w_re = w

        # max/min blur radius within a local 5x5 window
        radius = defocus.abs()
        size = 2
        radius = F.pad(radius, pad=(size, size, size, size), mode='replicate')
        radius = F.unfold(radius, kernel_size=2*size+1)
        radius = radius.reshape(b, -1, h_re, w_re)
        radius_max = radius.max(dim=1, keepdim=True)[0]
        radius_min = radius.min(dim=1, keepdim=True)[0]

        error = FunctionRender(radius_max, radius_min, delta1, delta2)
        error = gaussian_blur(error, 3)

        if short_size:
            error = F.interpolate(error, size=(h, w), mode='bilinear', align_corners=True)
        return error
if __name__ == '__main__':
    import os
    # os.environ['CUDA_VISIBLE_DEVICES'] = '7'

    defocus = 20 * torch.rand(1, 1, 1080, 1920).cuda()
    module = ModuleGenError().cuda()
    error = module(defocus, delta1=0.9, delta2=0.8, short_size=384)
    print(error)
OK. Thank you very much!!!
Hello! I ran into some problems while training ARNet and would like to ask for your advice. First, ARNet's loss is decreasing but fluctuates a bit; the loss function does not contain the error map loss term. Another problem is that the bokeh image output by ARNet has color errors. I don't know whether you have encountered these problems and would appreciate any ideas on how to solve them.
Maybe the model needs more time to train. Also, our training code is based on DeepFocus. You can check if there is something wrong.
Hello! In the training dataset bokehme_syn_data that you provided, do I need to normalize the disparity after reading 'disparity.exr'? That is, is the second step below required?
No, you don’t need to do that.
OK, thank you for your reply!
Hi, your work is great! I have downloaded the training dataset and found that disparity.exr is 512 × 512 × 3, like bokeh_gt and image (512 × 512 × 3). Could you please give some guidance on how to use it? Thank you very much!
You can try disp = cv2.imread(disp_path, -1).astype(np.float32).
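If the EXR loads with three identical channels, a minimal sketch of reading it could look like this (the path is hypothetical, and some OpenCV builds require the environment variable below to read EXR files):

import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'  # set before importing cv2 in some builds

import cv2
import numpy as np

disp = cv2.imread('disparity.exr', -1).astype(np.float32)  # hypothetical path
if disp.ndim == 3:       # if all three channels store the same value,
    disp = disp[..., 0]  # keep a single channel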
Thanks a lot!!
@wzfsjtu I trained the model using my own implementation of the training code, but it didn't work very well. I wonder if you have trained the model and how well it works.
@JuewenPeng Hi! I have a question. If I want to use a DSLR to create the dataset, how should I determine the parameters K and gamma corresponding to each bokeh image taken by the DSLR camera?
I think it is really hard. Besides, it is difficult to obtain the disparity map, and captured all-in-focus and bokeh image pairs suffer from color inconsistency and misalignment, as in the EBB! dataset.
OK, I got it. Thank you!
@JuewenPeng Hi! Have you carried out any pre-processing or post-processing operations to further improve the results? Could you provide the relevant code? The results I have obtained, both from my own re-implementation and from the model you provide, still fall a bit short of the metrics in your paper, so I would like to ask what pre-processing and post-processing operations can be used to further improve the results.
Do you mean the evaluation on the BLB dataset?
The pretrained model we provide in this repository is different from that in the original paper, but in our experiment, this one is much better. You can use the following code to evaluate the model on the BLB dataset. Remember to change the data path to your own.
# NOTE
# In the BLB dataset, the maximum values of the all-in-focus image and the bokeh ground truth may be larger than 1.
# The numerical ranges of the predicted bokeh image and the ground truth will be clipped to [0, 1] before evaluation.

import os
import cv2
import numpy as np
import time
import xlwt
import json
import warnings
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

############# import your model #############
from neural_renderer import ARNet, IUNet
from classical_renderer.scatter import ModuleRenderScatter  # circular aperture
#############################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def calc_psnr(pred, gt, data_range=1):
    if data_range == 1:
        pred = pred * 255
        gt = gt * 255
    mse = torch.mean((pred - gt) ** 2)
    if mse == 0:
        return float('inf')
    else:
        return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
def calc_ssim(X, Y, mask=None, data_range=1, size_average=True, win_size=11, win_sigma=1.5, win=None, K=(0.01, 0.03), nonnegative_ssim=False):
    r""" interface of ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        mask (torch.Tensor, optional): binary mask for a masked SSIM
        data_range (float or int, optional): value range of input images (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, the SSIM of all images is averaged to a scalar
        win_size (int, optional): the size of the Gaussian kernel
        win_sigma (float, optional): sigma of the normal distribution
        win (torch.Tensor, optional): 1-D Gaussian kernel. If None, a new kernel is created from win_size and win_sigma
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
        nonnegative_ssim (bool, optional): force the SSIM response to be nonnegative with relu
    Returns:
        torch.Tensor: ssim results
    """
    if not X.shape == Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    # squeeze trailing singleton dimensions
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if len(X.shape) not in (4, 5):
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if not X.type() == Y.type():
        raise ValueError("Input images should have the same dtype.")

    if win is not None:  # set win_size from the given kernel
        win_size = win.shape[-1]
    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    ssim_per_channel, cs = _ssim(X, Y, mask=mask, data_range=data_range, win=win, size_average=False, K=K)
    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    if size_average:
        return ssim_per_channel.mean()
    else:
        return ssim_per_channel.mean(1)
def _fspecial_gauss_1d(size, sigma):
    r"""Create a 1-D Gaussian kernel
    Args:
        size (int): the size of the Gaussian kernel
        sigma (float): sigma of the normal distribution
    Returns:
        torch.Tensor: 1D kernel (1 x 1 x size)
    """
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()
    return g.unsqueeze(0).unsqueeze(0)
def gaussian_filter(input, win):
    r""" Blur input with a 1-D kernel
    Args:
        input (torch.Tensor): a batch of tensors to be blurred
        win (torch.Tensor): 1-D Gaussian kernel
    Returns:
        torch.Tensor: blurred tensors
    """
    assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
    if len(input.shape) == 4:
        conv = F.conv2d
    elif len(input.shape) == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    C = input.shape[1]
    out = input
    # apply the separable 1-D kernel along each spatial dimension in turn
    for i, s in enumerate(input.shape[2:]):
        if s >= win.shape[-1]:
            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
        else:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
            )
    return out
def _ssim(X, Y, mask, data_range, win, size_average=True, K=(0.01, 0.03)):
    r""" Calculate the SSIM index for X and Y
    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        mask (torch.Tensor or None): optional mask for a masked SSIM
        win (torch.Tensor): 1-D Gaussian kernel
        data_range (float or int, optional): value range of input images (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, the SSIM of all images is averaged to a scalar
    Returns:
        torch.Tensor: ssim results.
    """
    K1, K2 = K
    # batch, channel, [depth,] height, width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    if mask is None:
        ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
        cs = torch.flatten(cs_map, 2).mean(-1)
    else:
        crop_size = int(win.shape[3] // 2)
        mask = mask[:, :, crop_size:-crop_size, crop_size:-crop_size]
        ssim_per_channel = torch.flatten(ssim_map * mask, 2).mean(-1) / torch.flatten(mask, 2).mean(-1)
        cs = torch.flatten(cs_map * mask, 2).mean(-1) / torch.flatten(mask, 2).mean(-1)

    return ssim_per_channel, cs
def style(bold=False, underline=False, italic=False, auto_warp=True, align_h='center', align_v='center'):
    style = xlwt.XFStyle()
    font = xlwt.Font()
    font.bold = bold
    font.underline = underline
    font.italic = italic
    style.font = font

    alignment = xlwt.Alignment()
    if align_h == 'center':
        alignment.horz = xlwt.Alignment.HORZ_CENTER
    elif align_h == 'left':
        alignment.horz = xlwt.Alignment.HORZ_LEFT
    elif align_h == 'right':
        alignment.horz = xlwt.Alignment.HORZ_RIGHT
    else:
        print('error')
        exit(0)
    if align_v == 'center':
        alignment.vert = xlwt.Alignment.VERT_CENTER
    elif align_v == 'top':
        alignment.vert = xlwt.Alignment.VERT_TOP
    elif align_v == 'bottom':
        alignment.vert = xlwt.Alignment.VERT_BOTTOM
    else:
        print('error')
        exit(0)
    alignment.wrap = int(auto_warp)
    style.alignment = alignment
    return style
########################## change to your rendering pipeline ##########################

def gaussian_blur(x, r, sigma=None):
    r = int(round(r))
    if sigma is None:
        sigma = 0.3 * (r - 1) + 0.8
    x_grid, y_grid = torch.meshgrid(torch.arange(-int(r), int(r) + 1), torch.arange(-int(r), int(r) + 1))
    kernel = torch.exp(-(x_grid ** 2 + y_grid ** 2) / 2 / sigma ** 2)
    kernel = kernel.float() / kernel.sum()
    kernel = kernel.expand(1, 1, 2*r+1, 2*r+1).to(x.device)
    x = F.pad(x, pad=(r, r, r, r), mode='replicate')
    x = F.conv2d(x, weight=kernel, padding=0)
    return x
def pipeline(classical_renderer, arnet, iunet, image, defocus, gamma, args):
    # classical rendering in linear space (gamma applied before, inverted after)
    bokeh_classical, defocus_dilate = classical_renderer(image**gamma, defocus*args.defocus_scale)
    bokeh_classical = bokeh_classical ** (1/gamma)
    defocus_dilate = defocus_dilate / args.defocus_scale
    gamma = (gamma - args.gamma_min) / (args.gamma_max - args.gamma_min)

    # neural rendering at a reduced scale so that |defocus| <= 1
    adapt_scale = max(defocus.abs().max().item(), 1)
    image_re = F.interpolate(image, scale_factor=1/adapt_scale, mode='bilinear', align_corners=True)
    defocus_re = 1 / adapt_scale * F.interpolate(defocus, scale_factor=1/adapt_scale, mode='bilinear', align_corners=True)
    bokeh_neural, error_map = arnet(image_re, defocus_re, gamma)
    error_map = F.interpolate(error_map, size=(image.shape[2], image.shape[3]), mode='bilinear', align_corners=True)
    bokeh_neural.clamp_(0, 1e5)

    # iterative upsampling with IUNet back to full resolution
    for scale in range(int(np.log2(adapt_scale))):
        ratio = 2**(scale+1) / adapt_scale
        h_re, w_re = int(ratio * image.shape[2]), int(ratio * image.shape[3])
        image_re = F.interpolate(image, size=(h_re, w_re), mode='bilinear', align_corners=True)
        defocus_re = ratio * F.interpolate(defocus, size=(h_re, w_re), mode='bilinear', align_corners=True)
        defocus_dilate_re = ratio * F.interpolate(defocus_dilate, size=(h_re, w_re), mode='bilinear', align_corners=True)
        bokeh_neural_refine = iunet(image_re, defocus_re.clamp(-1, 1), bokeh_neural, gamma).clamp(0, 1e5)
        mask = gaussian_blur(((defocus_dilate_re < 1) * (defocus_dilate_re > -1)).float(), 0.005 * (defocus_dilate_re.shape[2] + defocus_dilate_re.shape[3]))
        bokeh_neural = mask * bokeh_neural_refine + (1 - mask) * F.interpolate(bokeh_neural, size=(h_re, w_re), mode='bilinear', align_corners=True)

    bokeh_neural_refine = iunet(image, defocus.clamp(-1, 1), bokeh_neural, gamma).clamp(0, 1e5)
    mask = gaussian_blur(((defocus_dilate < 1) * (defocus_dilate > -1)).float(), 0.005 * (defocus_dilate.shape[2] + defocus_dilate.shape[3]))
    bokeh_neural = mask * bokeh_neural_refine + (1 - mask) * F.interpolate(bokeh_neural, size=(image.shape[2], image.shape[3]), mode='bilinear', align_corners=True)

    # fuse classical and neural renderings with the error map
    bokeh_pred = bokeh_classical * (1 - error_map) + bokeh_neural * error_map

    return bokeh_pred.clamp(0, 1), bokeh_classical.clamp(0, 1), bokeh_neural.clamp(0, 1), error_map
#######################################################################################
def main():
    gamma = 2.2

    ############################# change to your settings #############################
    method = 'BokehMe'                             # method name
    root = '/data2/pengjuewen/Bokeh/Blender/data'  # path to BLB dataset
    save_root = os.path.join('./BLB', method)      # path to save results

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(description='Bokeh Rendering', fromfile_prefix_chars='@')
    parser.add_argument('--defocus_scale', type=float, default=10.)
    parser.add_argument('--gamma_min', type=float, default=1.)
    parser.add_argument('--gamma_max', type=float, default=5.)

    # Model 1
    parser.add_argument('--arnet_shuffle_rate', type=int, default=2)
    parser.add_argument('--arnet_in_channels', type=int, default=5)
    parser.add_argument('--arnet_out_channels', type=int, default=4)
    parser.add_argument('--arnet_middle_channels', type=int, default=128)
    parser.add_argument('--arnet_num_block', type=int, default=3)
    parser.add_argument('--arnet_share_weight', action='store_true')
    parser.add_argument('--arnet_connect_mode', type=str, default='distinct_source')
    parser.add_argument('--arnet_use_bn', action='store_true')
    parser.add_argument('--arnet_activation', type=str, default='elu')

    # Model 2
    parser.add_argument('--iunet_shuffle_rate', type=int, default=2)
    parser.add_argument('--iunet_in_channels', type=int, default=8)
    parser.add_argument('--iunet_out_channels', type=int, default=3)
    parser.add_argument('--iunet_middle_channels', type=int, default=64)
    parser.add_argument('--iunet_num_block', type=int, default=3)
    parser.add_argument('--iunet_share_weight', action='store_true')
    parser.add_argument('--iunet_connect_mode', type=str, default='distinct_source')
    parser.add_argument('--iunet_use_bn', action='store_true')
    parser.add_argument('--iunet_activation', type=str, default='elu')

    # Checkpoint
    parser.add_argument('--arnet_checkpoint_path', type=str, default='./checkpoints/arnet.pth')
    parser.add_argument('--iunet_checkpoint_path', type=str, default='./checkpoints/iunet.pth')

    # Input
    args = parser.parse_args()

    arnet_checkpoint_path = args.arnet_checkpoint_path
    iunet_checkpoint_path = args.iunet_checkpoint_path

    classical_renderer = ModuleRenderScatter().to(device)
    arnet = ARNet(args.arnet_shuffle_rate, args.arnet_in_channels, args.arnet_out_channels, args.arnet_middle_channels,
                  args.arnet_num_block, args.arnet_share_weight, args.arnet_connect_mode, args.arnet_use_bn,
                  args.arnet_activation)
    iunet = IUNet(args.iunet_shuffle_rate, args.iunet_in_channels, args.iunet_out_channels, args.iunet_middle_channels,
                  args.iunet_num_block, args.iunet_share_weight, args.iunet_connect_mode, args.iunet_use_bn,
                  args.iunet_activation)
    arnet.cuda()
    iunet.cuda()
    checkpoint = torch.load(arnet_checkpoint_path)
    arnet.load_state_dict(checkpoint['model'])
    checkpoint = torch.load(iunet_checkpoint_path)
    iunet.load_state_dict(checkpoint['model'])
    arnet.eval()
    iunet.eval()
    ###################################################################################
    os.makedirs(save_root, exist_ok=True)

    scene_lst = [name for name in sorted(os.listdir(root)) if '.' not in name]
    disp_focus_lst = [f'({i+1})' for i in range(10)]
    metric_lst = ['psnr', 'ssim', 'runtime']
    scene_num = len(scene_lst)
    disp_focus_num = len(disp_focus_lst)
    metric_num = len(metric_lst)

    with torch.no_grad():
        for K_idx in range(5):
            scene_metric_avg = np.zeros([metric_num, scene_num])
            disp_focus_metric_avg = np.zeros([metric_num, disp_focus_num])

            # initialize excel
            workbook = xlwt.Workbook(encoding='utf-8')
            worksheets = [workbook.add_sheet(ind, cell_overwrite_ok=True) for ind in metric_lst]
            standard_style = style()
            left_style = style(align_h='left')
            for worksheet in worksheets:
                worksheet.write_merge(0, 1, 0, 1, method, style=standard_style)
                worksheet.write_merge(0, 0, 2, 1 + disp_focus_num, 'Refocused Disparity', style=standard_style)
                for i, disp_focus in enumerate(disp_focus_lst):
                    worksheet.write(1, 2 + i, disp_focus, style=standard_style)
                worksheet.write_merge(2, 1 + scene_num, 0, 0, 'Scene', style=standard_style)
                worksheet.write_merge(2 + scene_num, 2 + scene_num, 0, 1, 'Average', style=standard_style)
                worksheet.write_merge(0, 1, 2 + disp_focus_num, 2 + disp_focus_num, 'Average', style=standard_style)

            for scene_idx in range(scene_num):
                scene_name = scene_lst[scene_idx]
                scene_path = os.path.join(root, scene_name)
                save_scene_path = os.path.join(save_root, scene_name)
                os.makedirs(save_scene_path, exist_ok=True)

                image = cv2.imread(os.path.join(scene_path, 'image.exr'), -1)[..., :3].astype(np.float32) ** (1/gamma)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # input RGB and output RGB by default

                ###### change to the name of corrupted depth map if necessary ######
                depth_name = 'depth.exr'
                ####################################################################

                depth = cv2.imread(os.path.join(scene_path, depth_name), -1)[..., 0].astype(np.float32)
                disp = 1 / depth

                ############# comment it if using tensorflow model #############
                image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
                disp = torch.from_numpy(disp).unsqueeze(0).unsqueeze(0).contiguous().to(device)
                ################################################################

                for metric_idx in range(metric_num):
                    worksheets[metric_idx].write(2+scene_idx, 1, scene_name, style=standard_style)

                file = open(os.path.join(scene_path, 'info.json'), 'r')
                info_data = json.load(file)
                Ks = info_data['blur_parameters']
                focus_distances = info_data['focus_distances']

                for df_idx in range(len(disp_focus_lst)):
                    K = Ks[K_idx]
                    disp_focus = 1 / focus_distances[df_idx]
                    defocus = K * (disp - disp_focus) / args.defocus_scale

                    gt_name = f'bokeh_{K_idx:0>2d}_{df_idx:0>2d}.exr'
                    gt = cv2.imread(os.path.join(scene_path, gt_name), -1)[..., :3].astype(np.float32) ** (1/gamma)
                    gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
                    gt = torch.from_numpy(gt).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

                    # neglect the runtime of the first inference
                    if scene_idx + df_idx == 0:
                        pipeline(classical_renderer, arnet, iunet, image, defocus, gamma, args)

                    ###### comment it if using pytorch cpu or tensorflow model ######
                    torch.cuda.synchronize()
                    #################################################################
                    start = time.time()

                    pred = pipeline(classical_renderer, arnet, iunet, image, defocus, gamma, args)[0]

                    ###### comment it if using pytorch cpu or tensorflow model ######
                    torch.cuda.synchronize()
                    #################################################################
                    end = time.time()

                    ############ uncomment it if using tensorflow model ############
                    # pred = torch.from_numpy(pred).to(device)
                    ################################################################

                    # evaluation
                    psnr = calc_psnr(pred.clamp(0, 1), gt.clamp(0, 1))
                    ssim = calc_ssim(pred.clamp(0, 1), gt.clamp(0, 1))

                    # save results
                    pred = pred[0].cpu().clone().permute(1, 2, 0).numpy()
                    pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
                    save_name = f'bokeh_{K_idx:0>2d}_{df_idx:0>2d}.jpg'
                    cv2.imwrite(os.path.join(save_scene_path, save_name), pred * 255)

                    runtime = end - start

                    scene_metric_avg[0, scene_idx] += psnr
                    scene_metric_avg[1, scene_idx] += ssim
                    scene_metric_avg[2, scene_idx] += runtime
                    disp_focus_metric_avg[0, df_idx] += psnr
                    disp_focus_metric_avg[1, df_idx] += ssim
                    disp_focus_metric_avg[2, df_idx] += runtime

                    # write to excel
                    ii = scene_idx + 2
                    jj = df_idx + 2
                    worksheets[0].write(ii, jj, float(psnr), style=left_style)
                    worksheets[1].write(ii, jj, float(ssim), style=left_style)
                    worksheets[2].write(ii, jj, float(runtime), style=left_style)

                    print(f'scene[{scene_idx+1}/{scene_num}] disp_focus[{df_idx+1}/{disp_focus_num}] '
                          f'PSNR:{psnr} SSIM:{ssim} Runtime: {runtime}')

            scene_metric_avg /= disp_focus_num
            disp_focus_metric_avg /= scene_num
            assert np.abs(scene_metric_avg.mean(axis=1)[0] - disp_focus_metric_avg.mean(axis=1)[0]) < 1e-5
            metric_avg = scene_metric_avg.mean(axis=1)

            for scene_idx in range(scene_num):
                for df_idx in range(disp_focus_num):
                    ii = scene_idx + 2
                    jj = df_idx + 2
                    for metric_idx in range(metric_num):
                        worksheets[metric_idx].write(ii, 2+disp_focus_num, float(scene_metric_avg[metric_idx, scene_idx]), style=left_style)
                        worksheets[metric_idx].write(2+scene_num, jj, float(disp_focus_metric_avg[metric_idx, df_idx]), style=left_style)

            for metric_idx in range(metric_num):
                worksheets[metric_idx].write(2+scene_num, 2+disp_focus_num, float(metric_avg[metric_idx]), style=left_style)

            xls_name = f'evaluation_K={int(10*(K_idx+1))}.xls'
            workbook.save(os.path.join(save_root, xls_name))

    print(f'"{method}" evaluation done!')


if __name__ == '__main__':
    main()
OK, thank you very much!
@JuewenPeng Hi! I have two questions about model training. First, since you did not normalize 'image.exr' and 'bokeh.exr' after reading them when evaluating on BLB, do I also not need to normalize them when training the model? Second, are the numerical ranges of the predicted bokeh image and the ground truth clipped to [0, 1] before calculating the loss? Looking forward to your answers!
@JuewenPeng Hi! Could you provide the test dataset EBB400, with the corresponding disparity maps and refocused disparities? I would be very grateful!
For the first question, you don't need to normalize the images during training; values of all RGB images are supposed to be in the range [0, 1.5]. For the second question, we clip the values of bokeh images to [0, 1] for all methods, since some of them cannot output values outside [0, 1] (our methods can), so we think this clipping operation makes the comparisons fairer.
Do you really need EBB400? Honestly, I don't think this experiment is entirely fair or necessary, since there is color inconsistency and misalignment between the pairs of input image and bokeh image. We did it just to make the whole experiment more complete.
Here's the thing: I want to see the bokeh effect in real scenes, and EBB400 is just right for this need. However, generating disparity maps and determining refocused disparity values is time-consuming, so it would save me a lot of work if you could provide ready-made data.
OK, let me upload it to the Baidu Netdisk.
Thank you very much for your answers. On the second point (clipping), I would like to know whether it is necessary to clip the numerical ranges of the predicted bokeh and the ground-truth bokeh before calculating the loss when training the model.
Thank you very much!!!
I think it's optional, but in my practice, I didn't clip the predicted values during the training.
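To make this convention concrete, here is a minimal sketch (the variable names and the L1 loss are placeholders for illustration, not the authors' actual training code):

import torch
import torch.nn.functional as F

bokeh_pred = torch.rand(1, 3, 256, 256) * 1.5  # predictions may exceed 1 (RGB range is about [0, 1.5])
bokeh_gt = torch.rand(1, 3, 256, 256) * 1.5

# training: compute the loss on raw, unclipped values
loss = F.l1_loss(bokeh_pred, bokeh_gt)

# evaluation: clip both prediction and ground truth to [0, 1] before PSNR/SSIM
pred_eval = bokeh_pred.clamp(0, 1)
gt_eval = bokeh_gt.clamp(0, 1)
# psnr = calc_psnr(pred_eval, gt_eval)  # calc_psnr as in the evaluation script above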
OK, thank you very much!
I want to know, for the pretrained model you provide in this repository, did you train it using only the training set in bokehme_syn_data, or did you add some data from BLB? I used only the training set in bokehme_syn_data to train the model, and when testing on the BLB dataset, the SSIM metric only reaches 0.97. I wonder if there is anything I need to pay attention to in the training process.
We only trained our model on the synthetic dataset.
EBB400 Baidu Netdisk: https://pan.baidu.com/s/1l3Rug16HEB2uUi3u366vLw?pwd=f7mp
We conducted our experiment using the disparity maps in the 'disparity' directory, all of which are predicted by MiDaS. We also provide disparity maps predicted by DPT in the 'disparity_dpt' directory; you can use them if you expect better bokeh rendering effects.
OK, thank you very much!
Thanks for the EBB400. And there is another question: I can't download the entire EBB training set because the website (https://competitions.codalab.org/competitions/24716#participate) is no longer working. Do you have a backup here? Could you share it?
Sorry for that, but you'd better first register for the competition and then download the entire dataset.
But it seems that I can't register for the competition because it has ended.
I remember that one can register for the competition any time.
This?
@JuewenPeng Hi, could you provide the code for training the model? I would be very grateful!