YimianDai / open-aff

code and trained models for "Attentional Feature Fusion"

ASKCResNetFPN #26

Open · YimianDai opened this issue 3 years ago

YimianDai commented 3 years ago
from __future__ import division
import os
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon.nn import BatchNorm
from gluoncv.model_zoo.fcn import _FCNHead
from mxnet import nd

from .askc import LCNASKCFuse

from model.atac.backbone import ATACBlockV1, conv1ATAC, DynamicCell
from model.atac.convolution import LearnedCell, ChaDyReFCell, SeqDyReFCell, SK_ChaDyReFCell, \
    SK_1x1DepthDyReFCell, SK_MSSpaDyReFCell, SK_SpaDyReFCell, Direct_AddCell, SKCell, \
    SK_SeqDyReFCell, Sub_MSSpaDyReFCell, SK_MSSeqDyReFCell, iAAMSSpaDyReFCell
from model.atac.convolution import \
    LearnedConv, ChaDyReFConv, SeqDyReFConv, SK_ChaDyReFConv, \
    SK_1x1DepthDyReFConv, SK_MSSpaDyReFConv, SK_SpaDyReFConv, Direct_AddConv, SKConv, \
    SK_SeqDyReFConv
    # , SK_MSSeqDyReFConv
from .activation import xUnit, SpaATAC, ChaATAC, SeqATAC, MSSeqATAC, MSSeqATACAdd, \
    MSSeqATACConcat, MSSeqAttentionMap, xUnitAttentionMap
from model.atac.fusion import Direct_AddFuse_Reduce, SK_MSSpaFuse, SKFuse_Reduce, LocalChaFuse, \
    GlobalChaFuse, \
    LocalGlobalChaFuse_Reduce, LocalLocalChaFuse_Reduce, GlobalGlobalChaFuse_Reduce, \
    AYforXplusYChaFuse_Reduce, XplusAYforYChaFuse_Reduce, IASKCChaFuse_Reduce,\
    GAUChaFuse_Reduce, SpaFuse_Reduce, ConcatFuse_Reduce, AXYforXplusYChaFuse_Reduce,\
    BiLocalChaFuse_Reduce, BiGlobalChaFuse_Reduce, LocalGAUChaFuse_Reduce, GlobalSpaFuse,\
    AsymBiLocalChaFuse_Reduce, BiSpaChaFuse_Reduce, AsymBiSpaChaFuse_Reduce, LocalSpaFuse, \
    BiGlobalLocalChaFuse_Reduce

# from gluoncv.model_zoo.resnetv1b import BasicBlockV1b
from gluoncv.model_zoo.cifarresnet import CIFARBasicBlockV1

class ASKCResNetFPN(HybridBlock):
    def __init__(self, layers, channels, fuse_mode, act_dilation, classes=1, tinyFlag=False,
                 norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        super(ASKCResNetFPN, self).__init__(**kwargs)

        self.layer_num = len(layers)
        self.tinyFlag = tinyFlag
        with self.name_scope():

            stem_width = int(channels[0])
            self.stem = nn.HybridSequential(prefix='stem')
            self.stem.add(norm_layer(scale=False, center=False,
                                     **({} if norm_kwargs is None else norm_kwargs)))
            if tinyFlag:
                self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width*2))
                self.stem.add(nn.Activation('relu'))
            else:
                self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width))
                self.stem.add(nn.Activation('relu'))
                self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width))
                self.stem.add(nn.Activation('relu'))
                self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width*2))
                self.stem.add(nn.Activation('relu'))
                self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))

            # self.head1 = _FCNHead(in_channels=channels[1], channels=classes)
            # self.head2 = _FCNHead(in_channels=channels[2], channels=classes)
            # self.head3 = _FCNHead(in_channels=channels[3], channels=classes)
            # self.head4 = _FCNHead(in_channels=channels[4], channels=classes)

            self.head = _FCNHead(in_channels=channels[1], channels=classes)

            self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0],
                                           channels=channels[1], stride=1, stage_index=1,
                                           in_channels=channels[1])

            self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1],
                                           channels=channels[2], stride=2, stage_index=2,
                                           in_channels=channels[1])

            self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2],
                                           channels=channels[3], stride=2, stage_index=3,
                                           in_channels=channels[2])

            if self.layer_num == 4:
                self.layer4 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[3],
                                               channels=channels[4], stride=2, stage_index=4,
                                               in_channels=channels[3])
                self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[3],
                                               act_dilation=act_dilation)  # channels[3]

            self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[2],
                                           act_dilation=act_dilation)  # 64
            self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[1],
                                           act_dilation=act_dilation)  # 32

            # if fuse_order == 'reverse':
            #     self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[2])  # channels[2]
            #     self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[3])  # channels[3]
            #     self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]
            # elif fuse_order == 'normal':
               #  self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]
               #  self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]
               #  self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
                    norm_layer=BatchNorm, norm_kwargs=None):
        layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
        with layer.name_scope():
            downsample = (channels != in_channels) or (stride != 1)
            layer.add(block(channels, stride, downsample, in_channels=in_channels,
                            prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            for _ in range(layers-1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix='',
                                norm_layer=norm_layer, norm_kwargs=norm_kwargs))
        return layer

    def _fuse_layer(self, fuse_mode, channels, act_dilation):
        if fuse_mode == 'Direct_Add':
            fuse_layer = Direct_AddFuse_Reduce(channels=channels)
        elif fuse_mode == 'Concat':
            fuse_layer = ConcatFuse_Reduce(channels=channels)
        elif fuse_mode == 'SK':
            fuse_layer = SKFuse_Reduce(channels=channels)
        # elif fuse_mode == 'LocalCha':
        #     fuse_layer = LocalChaFuse(channels=channels)
        # elif fuse_mode == 'GlobalCha':
        #     fuse_layer = GlobalChaFuse(channels=channels)
        elif fuse_mode == 'LocalGlobalCha':
            fuse_layer = LocalGlobalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'LocalLocalCha':
            fuse_layer = LocalLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'GlobalGlobalCha':
            fuse_layer = GlobalGlobalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'IASKCChaFuse':
            fuse_layer = IASKCChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AYforXplusY':
            fuse_layer = AYforXplusYChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AXYforXplusY':
            fuse_layer = AXYforXplusYChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'XplusAYforY':
            fuse_layer = XplusAYforYChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'GAU':
            fuse_layer = GAUChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'LocalGAU':
            fuse_layer = LocalGAUChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'SpaFuse':
            fuse_layer = SpaFuse_Reduce(channels=channels, act_dilation=act_dilation)
        elif fuse_mode == 'BiLocalCha':
            fuse_layer = BiLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'BiGlobalLocalCha':
            fuse_layer = BiGlobalLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AsymBiLocalCha':
            fuse_layer = AsymBiLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'BiGlobalCha':
            fuse_layer = BiGlobalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'BiSpaCha':
            fuse_layer = BiSpaChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AsymBiSpaCha':
            fuse_layer = AsymBiSpaChaFuse_Reduce(channels=channels)
        # elif fuse_mode == 'LocalSpa':
        #     fuse_layer = LocalSpaFuse(channels=channels, act_dilation=act_dilation)
        # elif fuse_mode == 'GlobalSpa':
        #     fuse_layer = GlobalSpaFuse(channels=channels, act_dilation=act_dilation)
        # elif fuse_mode == 'SK_MSSpa':
        #     # fuse_layer.add(SK_MSSpaFuse(channels=channels, act_dilation=act_dilation))
        #     fuse_layer = SK_MSSpaFuse(channels=channels, act_dilation=act_dilation)
        else:
            raise ValueError('Unknown fuse_mode')

        return fuse_layer

    def hybrid_forward(self, F, x):

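        # NOTE: x.shape is only available in imperative (NDArray) mode; calling
        # hybridize() would break this lookup, since Symbols carry no concrete shape.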
        _, _, hei, wid = x.shape

        x = self.stem(x)      # down 4, 32
        c1 = self.layer1(x)   # down 4, 32
        c2 = self.layer2(c1)  # down 8, 64
        out = self.layer3(c2)  # down 16, 128
        if self.layer_num == 4:
            c4 = self.layer4(out)  # down 32
            if self.tinyFlag:
                c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4)  # down 4
            else:
                c4 = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16)  # down 16
            out = self.fuse34(c4, out)
        if self.tinyFlag:
            out = F.contrib.BilinearResize2D(out, height=hei//2, width=wid//2)  # down 2, 128
        else:
            out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8)  # down 8, 128
        out = self.fuse23(out, c2)
        if self.tinyFlag:
            out = F.contrib.BilinearResize2D(out, height=hei, width=wid)  # down 1
        else:
            out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4)  # down 4
        out = self.fuse12(out, c1)

        pred = self.head(out)
        if self.tinyFlag:
            out = pred
        else:
            out = F.contrib.BilinearResize2D(pred, height=hei, width=wid)  # down 1, input resolution

        ######### reverse order ##########
        # up_c2 = F.contrib.BilinearResize2D(c2, height=hei//4, width=wid//4)  # down 4
        # fuse2 = self.fuse12(up_c2, c1)  # down 4, channels[2]
        #
        # up_c3 = F.contrib.BilinearResize2D(c3, height=hei//4, width=wid//4)  # down 4
        # fuse3 = self.fuse23(up_c3, fuse2)  # down 4, channels[3]
        #
        # up_c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4)  # down 4
        # fuse4 = self.fuse34(up_c4, fuse3)  # down 4, channels[4]
        #

        ######### normal order ##########
        # out = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16)
        # out = self.fuse34(out, c3)
        # out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8)
        # out = self.fuse23(out, c2)
        # out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4)
        # out = self.fuse12(out, c1)
        # out = self.head(out)
        # out = F.contrib.BilinearResize2D(out, height=hei, width=wid)

        return out

    def evaluate(self, x):
        """evaluating network with inputs and targets"""
        return self.forward(x)
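
For anyone trying to instantiate the class above, here is a minimal, illustrative sketch. The layer counts, channel widths, fuse_mode, and act_dilation values are assumptions chosen only to be consistent with the shape comments in hybrid_forward (stem/layer1 at 32 channels, layer2 at 64, layer3 at 128); they are not necessarily the configuration used for the released models.

import mxnet as mx

# Assumed 3-stage configuration; channels[0]*2 must equal channels[1]
# so that the stem output width matches layer1's in_channels.
net = ASKCResNetFPN(layers=[4, 4, 4],
                    channels=[16, 32, 64, 128],
                    fuse_mode='AsymBiLocalCha',  # any key handled in _fuse_layer
                    act_dilation=16,
                    classes=1,
                    tinyFlag=False)
net.initialize(mx.init.Xavier(), ctx=mx.cpu())

x = mx.nd.random.uniform(shape=(1, 3, 480, 480))  # dummy input batch
out = net(x)
print(out.shape)  # (1, 1, 480, 480): per-pixel logits at the input resolution

Because hybrid_forward reads x.shape, the network has to stay in imperative mode for this sketch; do not call net.hybridize().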