Justin62628 / Squirrel-RIFE

效果更好的补帧软件,显存占用更小,是DAIN速度的10-25倍,包含抽帧处理,去除动漫卡顿感
GNU General Public License v3.0
3.11k stars 176 forks source link

减小不支持任意时刻的模型在处理任意帧数时的推理次数 #528

Closed routineLife1 closed 1 year ago

routineLife1 commented 1 year ago

问题描述: 在img0和img1之间补4帧,先前的方法是先经过三次二分迭代补帧得到7帧,之后抛弃多余的三帧,这样做非常浪费算力,且得到的时刻点0.125 0.375 0.625 0.875并不符合邻近的理论。

解决方案: 在进行最后一次迭代前, 先计算时间点, 算出用哪些帧补出剩余需要的帧较合适, 这样就能减小不支持任意时刻的模型在处理任意帧数时的推理次数, 并纠正时刻错误。

例: 补4帧时,二分推理补帧两次后得到了0.25 0.5 0.75时刻对应的中间帧, 下一次迭代将获得0.125 0.375 0.625 0.875四个时刻对应的中间帧, 0到1之间取四帧, 为了连贯性考虑, 得到的帧需要尽可能与0.2 0.4 0.6 0.8时刻接近,从得到的时刻中选择最邻近的则应该选择0.25 0.375 0.625 0.75, 则可确定最后一次迭代时不需要inference(0, 0.25)和inference(0.75, 1)的操作

routineLife1 commented 1 year ago

测试代码

import numpy as np

cnt = 0

def model(x, y):
    return (x + y) / 2

def make_inference(x, y, _n):
    """先前方法"""
    mid = model(x, y)
    global cnt
    cnt += 1
    if _n == 1:
        return [mid]
    first_half = make_inference(x, mid, _n=_n // 2)
    second_half = make_inference(mid, y, _n=_n // 2)
    if _n % 2:
        return [*first_half, mid, *second_half]  # 这里应该既返回帧,也返回对应时刻
    else:
        return [*first_half, *second_half]  # 这里应该既返回帧,也返回对应时刻

def make_inference_timestamp(x, y, _n):
    """注意该函数只生成时刻"""
    mid = model(x, y)
    if _n == 1:
        return [mid]
    first_half = make_inference_timestamp(x, mid, _n=_n // 2)
    second_half = make_inference_timestamp(mid, y, _n=_n // 2)
    if _n % 2:
        return [*first_half, mid, *second_half]
    else:
        return [*first_half, *second_half]

# 得到要补帧数对应的上下界, 例如4帧的下界为3, 上界为7
def get(x):
    __n = 0
    a = np.power(2, __n)
    b = a + np.power(2, __n + 1)
    while not (a <= x <= b):
        __n += 1
        a += np.power(2, __n)
        b = a + np.power(2, __n + 1)
    return a, b

def make_inference2(x, y, _n):
    global cnt
    cnt = 0
    low, high = get(_n)
    last = make_inference(x, y, low)  # 因为传入的是边界, 所以会导出所有时刻点
    next = make_inference_timestamp(x, y, high)  # 因为传入的是边界, 所以会导出所有时刻点

    std_timestamp = [i * 1 / (_n + 1) for i in range(1, _n + 1)]  # 需要的标准时刻点
    arg_timestamp = []  # 模型推理能提供的最邻近时刻点(变量名乱起的)

    # 求最邻近时刻点
    for i in range(len(std_timestamp)):
        nd_array = np.abs(np.array([std_timestamp[i]] * len(next)) - np.array(next))
        arg_timestamp.append(next[np.argmin(nd_array)])

    for a in arg_timestamp:
        if a in next and a not in last:
            cnt += 1  # 实现时在这里进行剩下的补帧操作
    return std_timestamp, arg_timestamp, cnt
    # return cnt

# for n in range(1, 24):
#     make_inference(0, 1, n)
#     x = cnt
#     y = make_inference2(0, 1, n)
#     cnt = 0
#     print(f'需要的帧数:{n} 加速百分比:{(x - y) / x * 100}%')

n = 4
print(f'需要帧数:{n} 先前方法导出时刻点:{make_inference(0, 1, n)} 模型推理次数:{cnt}')
res = make_inference2(0, 1, n)
print(f'标准时刻:{res[0]} 最邻近时刻:{res[1]} 模型推理次数:{res[2]}')
Justin62628 commented 1 year ago

当前非任意时刻模型使用人数太少,以后不考虑整合非任意时刻模型相关的优化实现