linfengWen98 / CAP-VSTNet

[CVPR 2023] CAP-VSTNet: Content Affinity Preserved Versatile Style Transfer
MIT License
120 stars 8 forks source link

Code for calculating the temporal loss #11

Open Rorschach-xyz opened 11 months ago

Rorschach-xyz commented 11 months ago

Could you share the code for calculating the temporal loss?

linfengWen98 commented 11 months ago

1. Prepare the optical flow files (.flo)

Note: The code below is messy. Please refer to the corresponding link for the variables and functions that are not defined here (just copy them from the link).

① Using RAFT to estimate flow

def load_image(imfile, DEVICE='cuda', max_size=1280):
    """Read an image file and return it as a batched float tensor.

    The image is converted to RGB. If its longer side exceeds ``max_size``
    it is shrunk with bicubic resampling so that the longer side becomes
    ``max_size`` (the shorter side is scaled by the same factor and
    truncated to an int). Pixel values are not normalised; they stay in
    the 0-255 range.

    Args:
        imfile: Path (or file object) accepted by ``Image.open``.
        DEVICE: Device the resulting tensor is moved to.
        max_size: Upper bound for the longer image side, in pixels.

    Returns:
        A float tensor with a leading batch axis of size 1 followed by the
        channel-first image, placed on ``DEVICE``.
    """
    picture = Image.open(imfile).convert('RGB')
    width, height = picture.size
    longest_side = max(width, height)

    if longest_side > max_size:
        # Scale both sides by the same factor so the aspect ratio is kept.
        target_size = (
            int(1.0 * width / longest_side * max_size),
            int(1.0 * height / longest_side * max_size),
        )
        picture = picture.resize(target_size, Image.BICUBIC)

    pixels = np.array(picture).astype(np.uint8)

    # Channel-last -> channel-first, then add the batch axis.
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float()
    return chw[None].to(DEVICE)

# Build RAFT wrapped in DataParallel, load the checkpoint, then unwrap the
# underlying module again (presumably because the checkpoint was saved from a
# DataParallel model and its keys carry that wrapper's prefix -- TODO confirm).
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to('cuda')
model.eval()

frame_list = [path1, path2, ...]   # Ordered list of paths to the input (content) video frames (.png)
interval = 1  # gap between the two frames of a pair (alternative setting: 10)

# Pair every frame with the frame `interval` positions later.
for first_path, second_path in zip(frame_list, frame_list[interval:]):
    # splitext keeps dots inside the frame name ("clip.0001.png" -> "clip.0001"),
    # so distinct frames never collapse onto the same .flo file.
    cn = os.path.splitext(os.path.basename(first_path))[0]
    arguments_strOut = os.path.join(directory, cn + '.flo')

    FirstFrame = load_image(first_path)
    SecondFrame = load_image(second_path)

    # Pad both frames to a size the network accepts; the padding is removed
    # from the predicted flow below.
    padder = InputPadder(FirstFrame.shape)
    FirstFrame, SecondFrame = padder.pad(FirstFrame, SecondFrame)

    with torch.no_grad():
        flow_low, flow_up = model(FirstFrame, SecondFrame, iters=20, test_mode=True)
    flow = padder.unpad(flow_up[0])

    # Save as .flo: 4-byte tag "PIEH" (80, 73, 69, 72), then two int32 values
    # taken from flow.shape[2] and flow.shape[1] (width, height when flow is
    # channel-first), then the float32 flow values with the channel axis last.
    # The context manager closes the file even if a write fails.
    with open(arguments_strOut, 'wb') as objOutput:
        np.array([80, 73, 69, 72], np.uint8).tofile(objOutput)
        np.array([flow.shape[2], flow.shape[1]], np.int32).tofile(objOutput)
        np.array(flow.cpu().numpy().transpose(1, 2, 0), np.float32).tofile(objOutput)

② Using PWC-Net to estimate flow

def _load_frame_reversed_channels(path):
    """Load an image as a channel-first float tensor scaled to [0, 1].

    The channel order is reversed (RGB -> BGR for a 3-channel RGB image)
    before the channel axis is moved to the front.
    """
    pixels = numpy.array(PIL.Image.open(path))
    chw = pixels[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)
    # The reversed slice has negative strides; make it contiguous for torch.
    return torch.FloatTensor(numpy.ascontiguousarray(chw))


frame_list = [path1, path2, ...]   # Ordered list of paths to the input (content) video frames (.png)
interval = 1  # gap between the two frames of a pair (alternative setting: 10)

# Pair every frame with the frame `interval` positions later.
for first_path, second_path in zip(frame_list, frame_list[interval:]):
    # splitext keeps dots inside the frame name ("clip.0001.png" -> "clip.0001"),
    # so distinct frames never collapse onto the same .flo file.
    cn = os.path.splitext(os.path.basename(first_path))[0]
    arguments_strOut = os.path.join(directory, cn + '.flo')
    print(video_name, first_path, second_path, arguments_strOut)

    FirstFrame = _load_frame_reversed_channels(first_path)
    SecondFrame = _load_frame_reversed_channels(second_path)

    tenOutput = estimate(FirstFrame, SecondFrame)

    # Save as .flo: 4-byte tag "PIEH" (80, 73, 69, 72), then two int32 values
    # taken from tenOutput.shape[2] and tenOutput.shape[1] (width, height when
    # the flow is channel-first), then the float32 flow with the channel axis last.
    # The context manager closes the file even if a write fails.
    with open(arguments_strOut, 'wb') as objOutput:
        numpy.array([80, 73, 69, 72], numpy.uint8).tofile(objOutput)
        numpy.array([tenOutput.shape[2], tenOutput.shape[1]], numpy.int32).tofile(objOutput)
        numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)

2. Calculate the temporal loss

def warp(img, flow):
    """Resample ``img`` along the negated ``flow`` field.

    A sampling map is built as ``pixel_grid - flow`` and handed to
    ``cv2.remap`` with bilinear interpolation, so every output pixel
    ``(x, y)`` is read from ``img`` at ``(x - flow_x, y - flow_y)``.

    Args:
        img: Image array; converted to float32 before remapping.
        flow: Array indexed as ``[row, col, 0]`` for the x displacement and
            ``[row, col, 1]`` for the y displacement. The negation creates
            a new array, so the caller's array is left untouched.

    Returns:
        The remapped image as returned by ``cv2.remap``.
    """
    rows, cols = flow.shape[:2]
    column_index = np.arange(cols)
    row_index = np.arange(rows)[:, np.newaxis]

    # Negated copy of the flow, turned into absolute sampling coordinates
    # in place (keeps the flow's own dtype during the addition).
    sample_map = -flow
    sample_map[:, :, 0] += column_index
    sample_map[:, :, 1] += row_index

    return cv2.remap(
        img.astype(np.float32),
        sample_map.astype(np.float32),
        None,
        cv2.INTER_LINEAR,
    )

def MSE(A, B):
    """Return the mean squared error between ``A`` and ``B``.

    Both inputs are promoted to float64 before subtracting. Without the
    promotion, unsigned-integer inputs (e.g. uint8 images straight from
    ``cv2.imread``) wrap around on subtraction and squaring and yield a
    silently wrong error.

    Args:
        A: Array-like or scalar.
        B: Array-like or scalar, broadcastable against ``A``.

    Returns:
        The mean of the element-wise squared differences.
    """
    diff = np.asarray(A, dtype=np.float64) - np.asarray(B, dtype=np.float64)
    return np.mean(np.square(diff))

frame_list = [path1, path2, ...]   # Ordered list of paths to the stylized video frames (.png)
forward_flow_list = [path1, path2, ...]   # Ordered list of flow files (.flo) estimated on the content video frames
interval = 1  # gap between the two frames of a pair (alternative setting: 10)

# One root-mean-squared error value is collected per frame pair.
temporal_loss_list = []

for i, (first_path, second_path) in enumerate(zip(frame_list, frame_list[interval:])):
    FirstFrame = cv2.imread(first_path)
    SecondFrame = cv2.imread(second_path)
    # cv2.imread signals failure by returning None instead of raising.
    if FirstFrame is None:
        raise IOError("Could not read frame: {}".format(first_path))
    if SecondFrame is None:
        raise IOError("Could not read frame: {}".format(second_path))

    forward_flow = cv2.readOpticalFlow(forward_flow_list[i])
    # warp() negates the flow it receives, so passing -forward_flow samples
    # SecondFrame at (x, y) + forward_flow, i.e. pulls it back onto FirstFrame.
    warpped_pre_frame = warp(SecondFrame, -forward_flow)

    # Optional masking (mask_list, H and W must be provided by the caller):
    # mask = cv2.imread(mask_list[i])
    # mask = 1-mask / 255.
    # mask *= warp(np.ones([H, W, 3]), -forward_flow)
    # warpped_pre_frame = warpped_pre_frame*mask
    # FirstFrame = FirstFrame*mask

    # Square root of the MSE on [0, 1]-scaled pixels.
    temporal_loss = MSE(warpped_pre_frame / 255., FirstFrame / 255.) ** 0.5
    temporal_loss_list.append(temporal_loss)

Note: The frame_list and forward_flow_list must have the same order. If you want to use the mask, ReReVST shows how to calculate it.

1621950900 commented 10 months ago

Hello, can you share the code for the temporal error heatmap?

linfengWen98 commented 10 months ago

3. Calculate the temporal error heatmap

import cv2
import numpy as np
import os

import matplotlib.pylab as plt
import seaborn as sns
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image

def warp(img, flow):
    """Warp ``img`` with ``cv2.remap`` using the negated ``flow``.

    The lookup coordinates are ``x - flow[..., 0]`` and ``y - flow[..., 1]``
    for every pixel ``(x, y)``; sampling is bilinear.

    Args:
        img: Image array; cast to float32 for the remap.
        flow: Displacement field whose first two axes are rows and columns
            and whose last axis holds (x, y). It is not modified, because
            the negation below produces a new array.

    Returns:
        The warped image produced by ``cv2.remap``.
    """
    height, width = flow.shape[:2]
    xs = np.arange(width)
    ys = np.arange(height)[:, np.newaxis]

    # Start from the negated flow and add the pixel grid in place so the
    # arithmetic stays in the flow's own dtype.
    coords = -flow
    coords[:, :, 0] += xs
    coords[:, :, 1] += ys

    source = img.astype(np.float32)
    lookup = coords.astype(np.float32)
    return cv2.remap(source, lookup, None, cv2.INTER_LINEAR)

def heatmap(path1, path2, forward_flow, mask=None):
    """Render the temporal error between two frames as a "hot" heatmap image.

    The second frame is warped onto the first with the given flow, the
    absolute per-pixel difference is averaged over the colour channels, and
    the result is drawn with seaborn and rasterised through matplotlib's
    Agg canvas.

    Args:
        path1: Path of the first frame.
        path2: Path of the second frame.
        forward_flow: Path of the optical-flow file read with
            ``cv2.readOpticalFlow``.
        mask: Optional path of a mask image. Pixels that are 255 in the mask
            get weight 0; the weight is additionally multiplied by a warped
            all-ones image, so pixels whose lookup falls outside the frame
            are suppressed as well.

    Returns:
        A uint8 array with the height and width of the frames and three
        channels in reversed (BGR) order, presumably so it can be passed
        straight to ``cv2.imwrite``.

    Raises:
        IOError: If one of the images cannot be read.
    """
    first_raw = cv2.imread(path1)
    second_raw = cv2.imread(path2)
    # cv2.imread signals failure by returning None instead of raising.
    if first_raw is None:
        raise IOError("Could not read frame: {}".format(path1))
    if second_raw is None:
        raise IOError("Could not read frame: {}".format(path2))
    FirstFrame = first_raw / 255.
    SecondFrame = second_raw / 255.

    forward_flow = cv2.readOpticalFlow(forward_flow)
    # warp() negates the flow it receives, so passing -forward_flow samples
    # SecondFrame at (x, y) + forward_flow.
    warpped_pre_frame = warp(SecondFrame, -forward_flow)

    if mask is not None:
        mask_raw = cv2.imread(mask)
        if mask_raw is None:
            raise IOError("Could not read mask: {}".format(mask))
        mask = 1 - mask_raw / 255.
        H, W = FirstFrame.shape[0], FirstFrame.shape[1]
        mask *= warp(np.ones([H, W, 3]), -forward_flow)
        warpped_pre_frame = warpped_pre_frame * mask
        FirstFrame = FirstFrame * mask

    # calculate temporal error (mean absolute difference over channels)
    error = np.abs(FirstFrame - warpped_pre_frame)
    error = np.mean(error, axis=2, keepdims=True)

    # set figure: figsize is [width, height] in inches; at matplotlib's
    # default of 100 dpi the canvas is at least as large as the frame.
    figsize = (int(error.shape[1] // 100) + 1, int(error.shape[0] // 100) + 1)
    fig = plt.figure(figsize=figsize)
    try:
        plt.axis('off')
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.subplots_adjust(left=0, bottom=0, right=1, top=1, hspace=0, wspace=0)
        plt.margins(0, 0)

        # draw heatmap
        sns.heatmap(error[:, :, 0], cmap="hot", cbar=False)

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # buffer_rgba() exposes the rendered pixels as (height, width, 4)
        # RGBA; np.array copies them so they outlive the figure.
        rgba = np.array(canvas.buffer_rgba())
    finally:
        # Close this specific figure even when drawing fails.
        plt.close(fig)

    # Shrink the canvas back to the exact frame resolution.
    image = Image.fromarray(rgba)
    image = image.resize((error.shape[1], error.shape[0]), Image.NEAREST)

    # Drop alpha and reverse the channel order (RGB -> BGR).
    image = np.asarray(image)[:, :, :3]
    image = image[:, :, ::-1]

    # cv2.imwrite(file_path, image)
    return image

Note: If you want to use the mask, ReReVST shows how to calculate it.

1621950900 commented 10 months ago

Thank you very much