eehoeskrap / PaperReview

꾸준희의 꾸준하게 논문 읽기 프로젝트 ✨
8 stars 0 forks source link

AdderNet: Do We Really Need Multiplications in Deep Learning? #18

Open eehoeskrap opened 1 year ago

eehoeskrap commented 1 year ago

Paper : https://arxiv.org/abs/1912.13200 GitHub : https://github.com/huawei-noah/AdderNet

CVPR 2020에서 소개된 AdderNet은 기존 딥러닝에서 multiplication operation이 addition operation 연산 보다 계산 복잡도가 높기 때문에 더하는 연산들을 곱셈으로 만드는 Adder Network를 제안합니다.

기존 Conv 구조는 Conv Filter 사이의 곱셈 연산을 통해 weight를 업데이트 하게 되는데, 이러한 방식은 GPU 메모리와 전력 소모가 많아서 mobile 환경에서 사용하기가 어렵다고 합니다. 그래서 기존 연구에서는 binarization 기반의 binarized filter를 활용한 binary adder 기법들이 소개되어왔고, 나름 준수한 성능을 보였다고 합니다. 하지만 이러한 binarization 아이디어는 기존 신경망의 mAP를 보장할 수 없고, 수렴 속도와 학습률이 저하되는 등 학습 과정이 안정화되지 않는다고 합니다.

feature를 visualization 해보면 다음과 같습니다.

image

기존 CNN 구조 같은 경우 각도로 classification을 수행하게 되지만, AdderNet 같은 경우 클러스터링 된 점의 좌표를 활용하여 class를 분류하게 됩니다. 그 이유는 CNN은 필터와 입력 간의 cross correlation을 계산 하는데, 이 값들이 정규화 되면서 conv 연산은 두 vector간 cosine 거리를 계산하는 것과 같다고 합니다. 반면 AdderNet은 l1 norm을 활용하여 classification을 하기 때문에 각각 다른 중심을 가진 포인트로 클러스터링 된다고 하네요. 즉 이 값들은 필터와 입력값 사이의 유사도로 활용할 수 있다고 합니다.

기존 conv 연산은 아래와 같이 이루어집니다. X가 입력 값, F가 filter이고, S는 사전에 정의된 유사도 측정 함수 입니다.

image

AdderNet 연산은 아래와 같습니다. 위에서 함수 S를 제거하고 아래와 같은 수식으로 정의됩니다. L1 norm을 사용하니까 식이 이렇게 되는거 같네요.

image

이 수식의 출력은 항상 음수 값을 갖는다고 합니다. 그 대신 여기서 Batch Norm을 활용하여 정규화한다고 하네요. 정규화 과정에서 곱셈이 포함되어있긴 하지만 conv에 있는 곱셈보다는 훨씬 연산량이 적다고 하네요. 그리고 기존 Conv layer를 Adder layer로 대체하여 AddNets을 구현할 수 있다고 합니다.

import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Function
import math

def adder2d_function(X, W, stride=1, padding=0):
    n_filters, d_filter, h_filter, w_filter = W.size()
    n_x, d_x, h_x, w_x = X.size()

    h_out = (h_x - h_filter + 2 * padding) / stride + 1
    w_out = (w_x - w_filter + 2 * padding) / stride + 1

    h_out, w_out = int(h_out), int(w_out)
    X_col = torch.nn.functional.unfold(X.view(1, -1, h_x, w_x), h_filter, dilation=1, padding=padding, stride=stride).view(n_x, -1, h_out*w_out)
    X_col = X_col.permute(1,2,0).contiguous().view(X_col.size(1),-1)
    W_col = W.view(n_filters, -1)

    out = adder.apply(W_col,X_col)

    out = out.view(n_filters, h_out, w_out, n_x)
    out = out.permute(3, 0, 1, 2).contiguous()

    return out

class adder(Function):
    @staticmethod
    def forward(ctx, W_col, X_col):
        ctx.save_for_backward(W_col,X_col)
        output = -(W_col.unsqueeze(2)-X_col.unsqueeze(0)).abs().sum(1)
        return output

    @staticmethod
    def backward(ctx,grad_output):
        W_col,X_col = ctx.saved_tensors
        grad_W_col = ((X_col.unsqueeze(0)-W_col.unsqueeze(2))*grad_output.unsqueeze(1)).sum(2)
        grad_W_col = grad_W_col/grad_W_col.norm(p=2).clamp(min=1e-12)*math.sqrt(W_col.size(1)*W_col.size(0))/5
        grad_X_col = (-(X_col.unsqueeze(0)-W_col.unsqueeze(2)).clamp(-1,1)*grad_output.unsqueeze(1)).sum(0)

        return grad_W_col, grad_X_col

class adder2d(nn.Module):

    def __init__(self,input_channel,output_channel,kernel_size, stride=1, padding=0, bias = False):
        super(adder2d, self).__init__()
        self.stride = stride
        self.padding = padding
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.kernel_size = kernel_size
        self.adder = torch.nn.Parameter(nn.init.normal_(torch.randn(output_channel,input_channel,kernel_size,kernel_size)))
        self.bias = bias
        if bias:
            self.b = torch.nn.Parameter(nn.init.uniform_(torch.zeros(output_channel)))

    def forward(self, x):
        output = adder2d_function(x,self.adder, self.stride, self.padding)
        if self.bias:
            output += self.b.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        return output

결과를 보면 다음과 같습니다.

image image image