# Open YimianDai opened 3 years ago
from __future__ import division

import os

from mxnet import nd
from mxnet.gluon import nn
from mxnet.gluon.block import HybridBlock
from mxnet.gluon.nn import BatchNorm
from gluoncv.model_zoo.cifarresnet import CIFARBasicBlockV1
from gluoncv.model_zoo.fcn import _FCNHead

from model.atac.backbone import ATACBlockV1, conv1ATAC, DynamicCell
from model.atac.convolution import (
    LearnedCell, ChaDyReFCell, SeqDyReFCell, SK_ChaDyReFCell,
    SK_1x1DepthDyReFCell, SK_MSSpaDyReFCell, SK_SpaDyReFCell, Direct_AddCell, SKCell,
    SK_SeqDyReFCell, Sub_MSSpaDyReFCell, SK_MSSeqDyReFCell, iAAMSSpaDyReFCell,
)
from model.atac.convolution import (
    LearnedConv, ChaDyReFConv, SeqDyReFConv, SK_ChaDyReFConv,
    SK_1x1DepthDyReFConv, SK_MSSpaDyReFConv, SK_SpaDyReFConv, Direct_AddConv, SKConv,
    SK_SeqDyReFConv,
)
from model.atac.fusion import (
    Direct_AddFuse_Reduce, SK_MSSpaFuse, SKFuse_Reduce, LocalChaFuse,
    GlobalChaFuse,
    LocalGlobalChaFuse_Reduce, LocalLocalChaFuse_Reduce, GlobalGlobalChaFuse_Reduce,
    AYforXplusYChaFuse_Reduce, XplusAYforYChaFuse_Reduce, IASKCChaFuse_Reduce,
    GAUChaFuse_Reduce, SpaFuse_Reduce, ConcatFuse_Reduce, AXYforXplusYChaFuse_Reduce,
    BiLocalChaFuse_Reduce, BiGlobalChaFuse_Reduce, LocalGAUChaFuse_Reduce, GlobalSpaFuse,
    AsymBiLocalChaFuse_Reduce, BiSpaChaFuse_Reduce, AsymBiSpaChaFuse_Reduce, LocalSpaFuse,
    BiGlobalLocalChaFuse_Reduce,
)
from .activation import (
    xUnit, SpaATAC, ChaATAC, SeqATAC, MSSeqATAC, MSSeqATACAdd,
    MSSeqATACConcat, MSSeqAttentionMap, xUnitAttentionMap,
)
from .askc import LCNASKCFuse


class ASKCResNetFPN(HybridBlock):
    """Residual backbone with a top-down fusion path and a single FCN head.

    The network is a stem, three or four residual stages built from
    ``CIFARBasicBlockV1``, a chain of fusion modules that merge each deeper
    feature map (bilinearly resized) with the next shallower one, and an
    ``_FCNHead`` applied to the final fused map.

    Parameters
    ----------
    layers : sequence of int
        Number of residual blocks per stage. ``len(layers) == 4`` enables the
        fourth stage and the ``fuse34`` module; otherwise only stages 1-3 are
        built.
    channels : sequence of int
        ``channels[0]`` is the stem width; ``channels[i]`` is the output
        width of stage ``i``.
    fuse_mode : str
        Key selecting the fusion module, see ``_fuse_layer``.
    act_dilation
        Forwarded to the fusion module when ``fuse_mode == 'SpaFuse'``;
        unused by every other mode.
    classes : int
        Number of output channels of the head.
    tinyFlag : bool
        If True the stem performs no spatial down-sampling and the prediction
        is returned at the resolution of the fused map; otherwise the stem
        down-samples (stride-2 conv plus stride-2 max-pool) and the
        prediction is resized back to the input size.
    norm_layer, norm_kwargs
        Normalisation block class and its extra keyword arguments. They are
        applied to the stem, the residual stages and the head.
    """

    # Fusion modes whose constructor only takes ``channels``.
    _CHANNEL_ONLY_FUSE = {
        'Direct_Add': Direct_AddFuse_Reduce,
        'Concat': ConcatFuse_Reduce,
        'SK': SKFuse_Reduce,
        'LocalGlobalCha': LocalGlobalChaFuse_Reduce,
        'LocalLocalCha': LocalLocalChaFuse_Reduce,
        'GlobalGlobalCha': GlobalGlobalChaFuse_Reduce,
        'IASKCChaFuse': IASKCChaFuse_Reduce,
        'AYforXplusY': AYforXplusYChaFuse_Reduce,
        'AXYforXplusY': AXYforXplusYChaFuse_Reduce,
        'XplusAYforY': XplusAYforYChaFuse_Reduce,
        'GAU': GAUChaFuse_Reduce,
        'LocalGAU': LocalGAUChaFuse_Reduce,
        'BiLocalCha': BiLocalChaFuse_Reduce,
        'BiGlobalLocalCha': BiGlobalLocalChaFuse_Reduce,
        'AsymBiLocalCha': AsymBiLocalChaFuse_Reduce,
        'BiGlobalCha': BiGlobalChaFuse_Reduce,
        'BiSpaCha': BiSpaChaFuse_Reduce,
        'AsymBiSpaCha': AsymBiSpaChaFuse_Reduce,
    }

    def __init__(self, layers, channels, fuse_mode, act_dilation, classes=1, tinyFlag=False,
                 norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        super(ASKCResNetFPN, self).__init__(**kwargs)

        self.layer_num = len(layers)
        self.tinyFlag = tinyFlag
        # Expanded form of norm_kwargs for direct norm_layer(...) calls.
        norm_args = {} if norm_kwargs is None else norm_kwargs

        def add_conv_unit(out_channels, strides):
            """Append a 3x3 conv -> norm -> ReLU triple to the stem."""
            self.stem.add(nn.Conv2D(channels=out_channels, kernel_size=3, strides=strides,
                                    padding=1, use_bias=False))
            self.stem.add(norm_layer(in_channels=out_channels, **norm_args))
            self.stem.add(nn.Activation('relu'))

        # Children are created in a fixed order (stem, head, stages, fusion)
        # so that auto-generated block names stay stable.
        with self.name_scope():
            stem_width = int(channels[0])
            self.stem = nn.HybridSequential(prefix='stem')
            # Parameter-free normalisation of the raw input.
            self.stem.add(norm_layer(scale=False, center=False, **norm_args))
            if tinyFlag:
                add_conv_unit(stem_width * 2, strides=1)
            else:
                add_conv_unit(stem_width, strides=2)
                add_conv_unit(stem_width, strides=1)
                add_conv_unit(stem_width * 2, strides=1)
                self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))

            self.head = _FCNHead(in_channels=channels[1], channels=classes,
                                 norm_layer=norm_layer, norm_kwargs=norm_kwargs)

            self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0],
                                           channels=channels[1], stride=1, stage_index=1,
                                           in_channels=channels[1],
                                           norm_layer=norm_layer, norm_kwargs=norm_kwargs)
            self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1],
                                           channels=channels[2], stride=2, stage_index=2,
                                           in_channels=channels[1],
                                           norm_layer=norm_layer, norm_kwargs=norm_kwargs)
            self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2],
                                           channels=channels[3], stride=2, stage_index=3,
                                           in_channels=channels[2],
                                           norm_layer=norm_layer, norm_kwargs=norm_kwargs)
            if self.layer_num == 4:
                self.layer4 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[3],
                                               channels=channels[4], stride=2, stage_index=4,
                                               in_channels=channels[3],
                                               norm_layer=norm_layer,
                                               norm_kwargs=norm_kwargs)

            # Each fusion module is sized by the shallower of its two inputs.
            if self.layer_num == 4:
                self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[3],
                                               act_dilation=act_dilation)
            self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[2],
                                           act_dilation=act_dilation)
            self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[1],
                                           act_dilation=act_dilation)

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
                    norm_layer=BatchNorm, norm_kwargs=None):
        """Build one residual stage as a ``HybridSequential`` of ``layers`` blocks.

        The first block applies ``stride`` and gets a down-sample shortcut
        whenever the channel count changes or ``stride != 1``; the remaining
        ``layers - 1`` blocks keep resolution and width.
        """
        layer = nn.HybridSequential(prefix='stage%d_' % stage_index)
        with layer.name_scope():
            downsample = (channels != in_channels) or (stride != 1)
            layer.add(block(channels, stride, downsample, in_channels=in_channels,
                            prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            for _ in range(layers - 1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix='',
                                norm_layer=norm_layer, norm_kwargs=norm_kwargs))
        return layer

    def _fuse_layer(self, fuse_mode, channels, act_dilation):
        """Instantiate the fusion module selected by ``fuse_mode``.

        Raises
        ------
        ValueError
            If ``fuse_mode`` is not one of the supported keys.
        """
        if fuse_mode == 'SpaFuse':
            # NOTE(review): the keyword is spelled 'act_dialtion' here; it is
            # kept as-is because SpaFuse_Reduce's signature is not visible
            # from this file -- confirm against model.atac.fusion.
            return SpaFuse_Reduce(channels=channels, act_dialtion=act_dilation)
        try:
            fuse_cls = self._CHANNEL_ONLY_FUSE[fuse_mode]
        except (KeyError, TypeError):
            raise ValueError('Unknown fuse_mode')
        return fuse_cls(channels=channels)

    def hybrid_forward(self, F, x):
        """Run the backbone, the top-down fusion path and the head.

        NOTE(review): the input size is read from ``x.shape``, so ``x`` must
        expose a concrete shape; this presumably limits the block to
        imperative (non-hybridized) execution -- confirm.
        """
        _, _, hei, wid = x.shape

        def resize(feat, factor):
            """Bilinearly resize ``feat`` to the input size divided by ``factor``."""
            return F.contrib.BilinearResize2D(feat, height=hei // factor,
                                              width=wid // factor)

        # Down-sampling factor (relative to the input) of the shallower map
        # at each fusion step; the tiny variant keeps the stem at full
        # resolution, so every factor is four times smaller.
        if self.tinyFlag:
            scale34, scale23, scale12 = 4, 2, 1
        else:
            scale34, scale23, scale12 = 16, 8, 4

        x = self.stem(x)
        c1 = self.layer1(x)
        c2 = self.layer2(c1)
        out = self.layer3(c2)
        if self.layer_num == 4:
            c4 = self.layer4(out)
            out = self.fuse34(resize(c4, scale34), out)
        out = self.fuse23(resize(out, scale23), c2)
        out = self.fuse12(resize(out, scale12), c1)

        pred = self.head(out)
        if self.tinyFlag:
            # The fused map was resized to the input size before fuse12.
            return pred
        return resize(pred, 1)

    def evaluate(self, x):
        """Run a forward pass on ``x`` and return the prediction."""
        return self.forward(x)