bruinxiong / SENet.mxnet

:fire::fire: A MXNet implementation of Squeeze-and-Excitation Networks (SE-ResNext, SE-Resnet, SE-Inception-v4 and SE-Inception-Resnet-v2) :fire::fire:
Apache License 2.0

What is the difference between train_480_q90.rec and train_256_q90.rec? #8

Open miraclewkf opened 6 years ago

miraclewkf commented 6 years ago

What is the difference between train_480_q90.rec and train_256_q90.rec?
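(For what it's worth, these look like the standard MXNet ImageNet record files: 480 vs. 256 would be the shorter-edge size images were resized to when the .rec was packed, and q90 a JPEG quality of 90. Under that assumption, a hypothetical packing command with MXNet's im2rec tool would look like:)

    # hypothetical invocation; assumes train_480_q90.lst already exists next to the prefix
    python tools/im2rec.py --resize 480 --quality 90 --num-thread 16 train_480_q90 /path/to/imagenet/train/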

shiyuanyin commented 5 years ago

@bruinxiong Hello author, I'd like to ask you something. In the part corresponding to the SENet network structure, I want to implement the line below in MXNet: get the (b, c, h, w) of the bn output and then reshape it, but I don't know how to do that. It is SGE, a structure that improves on SE. Because an MXNet network symbol cannot directly give you (b, c, h, w), I don't know how to implement it: b, c, h, w = x.size(); x = x.reshape(b * 64, -1, h, w)

Below is the corresponding PyTorch implementation:

    import torch
    import torch.nn as nn
    from torch.nn import Parameter

    class SpatialGroupEnhance(nn.Module):   # conv 3/2/1 halves h,w; 3/1/1 keeps the size
        def __init__(self, groups=64):
            super(SpatialGroupEnhance, self).__init__()
            self.groups = groups
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.weight = Parameter(torch.zeros(1, groups, 1, 1))
            self.bias = Parameter(torch.ones(1, groups, 1, 1))
            self.sig = nn.Sigmoid()

        def forward(self, x):                          # (b, c, h, w)
            b, c, h, w = x.size()
            x = x.view(b * self.groups, -1, h, w)      # split channels into groups
            xn = x * self.avg_pool(x)                  # global pooling: (h, w) -> (1, 1)
            xn = xn.sum(dim=1, keepdim=True)           # (b*groups, 1, h, w)
            t = xn.view(b * self.groups, -1)
            t = t - t.mean(dim=1, keepdim=True)        # normalize: subtract mean ...
            std = t.std(dim=1, keepdim=True) + 1e-5
            t = t / std                                # ... and divide by std
            t = t.view(b, self.groups, h, w)
            t = t * self.weight + self.bias            # per-group scale and shift
            t = t.view(b * self.groups, 1, h, w)
            x = x * self.sig(t)                        # sigmoid gate: group factor in (0, 1)
            x = x.view(b, c, h, w)                     # restore the original dimensions
            return x
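For reference, a quick shape check of the module above (the batch size and channel count are arbitrary; the channel count must be divisible by groups):

    sge = SpatialGroupEnhance(groups=64)
    x = torch.randn(2, 256, 14, 14)
    y = sge(x)
    print(y.shape)   # torch.Size([2, 256, 14, 14]): the enhancement preserves the shape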

bruinxiong commented 5 years ago

@shiyuanyin Hello, thank you for your interest. From your message I roughly understand what you want to do. First, the reshape itself can be done with mxnet.symbol.reshape, but the output is a symbol, not the shape of the data itself; if you want to debug or obtain the shape, you need the infer_shape method of the mxnet.symbol class. For your specific problem, you can use mxnet.symbol.reshape(data=the symbol you want to reshape, say the output of the previous layer such as excitation, shape=(b*groups, -1, h, w)), where b is your batch size, which you can either specify up front or obtain with mxnet.symbol.infer_shape.
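A minimal sketch of that pattern (the convolution layer, the name excitation, and the input shape below are placeholders for illustration, not values from this repo):

    import mxnet as mx

    b, groups, h, w = 32, 64, 7, 7
    data = mx.symbol.Variable('data')
    excitation = mx.symbol.Convolution(data=data, num_filter=256, kernel=(3, 3),
                                       pad=(1, 1), name='excitation')
    # reshape returns a new symbol, not concrete data ...
    grouped = mx.symbol.reshape(excitation, shape=(b * groups, -1, h, w))
    # ... so shapes are obtained by inferring them from a concrete input shape
    arg_shapes, out_shapes, aux_shapes = grouped.infer_shape(data=(b, 3, h, w))
    print(out_shapes[0])   # (2048, 4, 7, 7), i.e. (b*groups, c/groups, h, w)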

shiyuanyin commented 5 years ago

@bruinxiong That implementation was slightly wrong; I changed it later to the following:

    if use_sge:
        groups = 64
        batch = config.per_batch_size
        data_shape = (batch, 3, 112, 112)
        # infer_shape returns (arg_shapes, out_shapes, aux_shapes); take bn3's output shape
        shape_size = bn3.infer_shape(data=data_shape)[1][0]

        # (b, c, h, w) -> (b, groups, c/groups, h, w) -> (b*groups, c/groups, h, w)
        x = mx.sym.reshape(bn3, shape=(0, -4, groups, -1, 0, 0))
        x = mx.sym.reshape(x, shape=(-3, 0, 0, 0))

        # global average pooling: (h, w) -> (1, 1)
        avg = mx.sym.Pooling(data=x, global_pool=True, kernel=(7, 7), pool_type='avg', name=name + '_sge_pool')
        body = mx.symbol.broadcast_mul(x, avg)
        body = mx.symbol.sum(data=body, axis=1, keepdims=True)    # (b*groups, 1, h, w)

        body = mx.sym.reshape(body, shape=(0, -1))                # (b*groups, h*w)
        mean = mx.sym.mean(body, axis=1, keepdims=True)
        mean = mx.sym.reshape(mean, shape=(-4, -1, 1))            # (b*groups, 1)
        body = mx.sym.broadcast_sub(body, mean)                   # subtract the mean

        std = mx.sym.mean(body**2, axis=1, keepdims=True) + 1e-5  # mean of (x-mean)^2, i.e. the variance
        std = mx.symbol.sqrt(std)
        body = mx.symbol.broadcast_div(body, std)                 # normalize: (x - mean) / std

        body = mx.sym.reshape(body, shape=(-4, -1, groups, 0))    # (b, groups, h*w)
        body = mx.sym.reshape(body, shape=(0, 0, -4, -1, shape_size[-1]))  # (b, groups, h, w)

        sge_weight = mx.symbol.Variable(name + "_sge_weight", shape=(1, 64, 1, 1), lr_mult=1.0, wd_mult=1.0, init=mx.init.Normal(0.01))
        sge_bias = mx.symbol.Variable(name + "_sge_bias", shape=(1, 64, 1, 1), lr_mult=1.0, wd_mult=1.0, init=mx.init.Normal(0.01))

        body = mx.symbol.broadcast_add(mx.symbol.broadcast_mul(body, sge_weight), sge_bias)
        # print(body.list_arguments())
        body = mx.sym.reshape(body, shape=(0, -4, groups, -1, 0, 0))
        body = mx.sym.reshape(body, shape=(-3, 0, 0, 0))          # (b*groups, 1, h, w)
        body = mx.symbol.Activation(data=body, act_type='sigmoid', name=name + "_sge_sigmoid")
        body = mx.symbol.broadcast_mul(x, body)                   # sigmoid gate: per-group factor in (0, 1)

        # restore the original dimensions: (b*groups, c/groups, h, w) -> (b, c, h, w)
        body = mx.sym.reshape(body, shape=(-4, -1, groups, 0, 0, 0))
        bn3 = mx.sym.reshape(body, shape=(0, -3, 0, 0))
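A note on the special reshape values that carry the whole trick above: 0 copies a dimension from the input, -1 infers a dimension, -3 merges two consecutive dimensions, and -4 splits one dimension into the two values that follow it. A small sketch of the group split/merge (the input shape here is assumed for illustration):

    import mxnet as mx

    x = mx.symbol.Variable('x')
    split = mx.symbol.reshape(x, shape=(0, -4, 64, -1, 0, 0))   # (b, c, h, w) -> (b, 64, c/64, h, w)
    merged = mx.symbol.reshape(split, shape=(-3, 0, 0, 0))      # -> (b*64, c/64, h, w)
    print(merged.infer_shape(x=(8, 256, 7, 7))[1][0])           # (512, 4, 7, 7)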