allenai / XNOR-Net

ImageNet classification using binary Convolutional Neural Networks
https://xnor.ai/

BinConvolution doesn't seem to match paper #41

Open · honglh opened 4 years ago

honglh commented 4 years ago

Notation: (*) denotes the convolution operation; o denotes the element-wise product.

    -- https://github.com/allenai/XNOR-Net/blob/master/models/alexnetxnor.lua#L16
    local function BinConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
      local C = nn.Sequential()
      C:add(nn.SpatialBatchNormalization(nInputPlane, 1e-4, false))
      C:add(activation())
      C:add(cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH))
      return C
    end
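
For readers more used to PyTorch, a rough equivalent of this Lua block might look like the sketch below. It is only a sketch: `bin_convolution` and the minimal `BinActiveZ` are my names, not from the repo, and the repo's real activation also implements a custom backward pass. The point is that the chain is just batch norm -> sign -> plain convolution, with no `K` factor anywhere:

    import torch.nn as nn

    class BinActiveZ(nn.Module):
        # forward-only stand-in for the repo's activation (sign of the input);
        # the real module also provides a straight-through-style gradient
        def forward(self, x):
            return x.sign()

    def bin_convolution(n_in, n_out, kW, kH, dW, dH, padW, padH):
        return nn.Sequential(
            nn.BatchNorm2d(n_in, eps=1e-4, affine=False),  # SpatialBatchNormalization(nInputPlane, 1e-4, false)
            BinActiveZ(),                                  # activation(): binarize the input to sign(I)
            nn.Conv2d(n_in, n_out, (kH, kW), (dH, dW), (padH, padW)),  # plain conv on the binarized input
        )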

In this implementation, the input activation() is followed directly by cudnn.SpatialConvolution() with the layer's weights. But the paper's algorithm for input binarization is:

I (*) W ~= (sign(I) (*) sign(W)) o (K a) = (sign(I) (*) sign(W)) o ((A (*) k) a)

where

A = torch.mean(input.abs(), dim=1, keepdim=True) (the per-pixel mean of |I| across channels)
k = an averaging kernel with every entry equal to 1/(w*h)
K = A (*) k
a = alpha, the per-filter scaling factor for the binarized weights

So A (*) k averages each input element with its neighboring elements. This step is missing from the current implementation, which computes only (sign(I) (*) sign(W)) o a:

    C:add(activation())
    C:add(cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH))

To capture the convolution of A with k from the paper, I would expect pseudocode like the following inside BinConvolution(), written in Python:

    A = x.abs().mean(dim=1, keepdim=True)  # mean of |I| over channels, shape (N, 1, H, W); must be taken before binarization
    x = BinActiveZ(x)                      # sign(I), shape (N, Cin, H, W)
    # <=== === === === === === === === === === START
    kH = self.conv.weight.shape[2]  # kernel height
    kW = self.conv.weight.shape[3]  # kernel width
    k = torch.ones(1, 1, kH, kW) / (kH * kW)  # averaging kernel k, each entry 1/(kW*kH)
    K = torch.nn.functional.conv2d(A, k, padding=(kH // 2, kW // 2))  # K = A (*) k, shape (N, 1, H, W) for odd kernel sizes

    # now calculate (sign(I) (*) sign(W)) o K a
    # since self.conv.weight is already binarized by binarizeConvParams() before the
    # batch starts, the factor `a` in `K a` is folded into `self.conv(x)`; the only
    # missing piece is the multiplication by K. Hence:
    x = self.conv(x).mul(K)
    # <=== === === === === === === === === === END
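
Putting the pieces together, a minimal self-contained sketch of the corrected block could look like this. All names here are mine (`BinConv2d` is hypothetical, not from the repo), gradients through `sign()` are ignored for brevity, and the per-filter factor `a` is assumed to already be folded into `conv.weight` by `binarizeConvParams()`, as in the repo:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BinConv2d(nn.Module):
        # batch norm -> binarize -> binary conv, scaled by K = A (*) k as in the paper
        def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0):
            super().__init__()
            self.bn = nn.BatchNorm2d(n_in, eps=1e-4, affine=False)
            self.conv = nn.Conv2d(n_in, n_out, kernel_size, stride, padding)

        def forward(self, x):
            x = self.bn(x)
            A = x.abs().mean(dim=1, keepdim=True)       # per-pixel mean of |I| over channels
            x = x.sign()                                # sign(I); real code needs a straight-through estimator
            kH, kW = self.conv.weight.shape[2:]
            k = torch.ones(1, 1, kH, kW, device=x.device) / (kH * kW)  # averaging kernel
            K = F.conv2d(A, k, stride=self.conv.stride,
                         padding=self.conv.padding)     # K = A (*) k, same spatial size as the conv output
            return self.conv(x) * K                     # (sign(I) (*) sign(W) a) o K

    # usage sketch
    layer = BinConv2d(192, 384, kernel_size=3, padding=1)
    y = layer(torch.randn(2, 192, 13, 13))
    print(y.shape)  # torch.Size([2, 384, 13, 13])

Reusing the main convolution's stride and padding for the averaging convolution keeps K aligned with the conv output at every spatial position.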

Can you check if my understanding of the discrepancy is correct?

leejiajun commented 3 years ago

@honglh Did you figure it out? Besides, I'm curious about the first layer, which uses floating point even though the paper says it uses +1 and -1. Did I misunderstand?