leondgarse / keras_cv_attention_models

Keras beit,caformer,CMT,CoAtNet,convnext,davit,dino,efficientdet,edgenext,efficientformer,efficientnet,eva,fasternet,fastervit,fastvit,flexivit,gcvit,ghostnet,gpvit,hornet,hiera,iformer,inceptionnext,lcnet,levit,maxvit,mobilevit,moganet,nat,nfnets,pvt,swin,tinynet,tinyvit,uniformer,volo,vanillanet,yolor,yolov7,yolov8,yolox,gpt2,llama2, alias kecam
MIT License
595 stars 95 forks source link

Update for EdgeNeXt #66

Closed whalefa1I closed 2 years ago

whalefa1I commented 2 years ago

I reproduced EdgeNeXt based on torch and your project, Is there any mistake with this code? Why can't it show all layers details,looks like it's missing some layers in “summary”

import tensorflow as tf
from tensorflow import keras
from keras_cv_attention_models.common_layers import (
    layer_norm, activation_by_name
)
from tensorflow.keras import initializers
from keras_cv_attention_models.attention_layers import (
    conv2d_no_bias,
    drop_block,
)
import math

BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
TF_BATCH_NORM_EPSILON = 0.001
LAYER_NORM_EPSILON = 1e-5

@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class PositionalEncodingFourier(keras.layers.Layer):
    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super(PositionalEncodingFourier, self).__init__()
        self.token_projection = tf.keras.layers.Conv2D(dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.eps = 1e-6

    def __call__(self, B, H, W, *args, **kwargs):
        mask_tf = tf.zeros([B, H, W])
        not_mask_tf = 1 - mask_tf
        y_embed_tf = tf.cumsum(not_mask_tf, axis=1)
        x_embed_tf = tf.cumsum(not_mask_tf, axis=2)
        y_embed_tf = y_embed_tf / (y_embed_tf[:, -1:, :] + self.eps) * self.scale  # 2 * math.pi
        x_embed_tf = x_embed_tf / (x_embed_tf[:, :, -1:] + self.eps) * self.scale  # 2 * math.pi
        dim_t_tf = tf.range(self.hidden_dim, dtype=tf.float32)
        dim_t_tf = self.temperature ** (2 * (dim_t_tf // 2) / self.hidden_dim)
        pos_x_tf = x_embed_tf[:, :, :, None] / dim_t_tf
        pos_y_tf = y_embed_tf[:, :, :, None] / dim_t_tf
        pos_x_tf = tf.reshape(tf.stack([tf.math.sin(pos_x_tf[:, :, :, 0::2]),
                                        tf.math.cos(pos_x_tf[:, :, :, 1::2])], axis=4),
                              shape=[B, H, W, self.hidden_dim])
        pos_y_tf = tf.reshape(tf.stack([tf.math.sin(pos_y_tf[:, :, :, 0::2]),
                                        tf.math.cos(pos_y_tf[:, :, :, 1::2])], axis=4),
                              shape=[B, H, W, self.hidden_dim])
        pos_tf = tf.concat([pos_y_tf, pos_x_tf], axis=-1)
        pos_tf = self.token_projection(pos_tf)

        return pos_tf

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"token_projection": self.token_projection, "scale": self.scale,
                            "temperature": self.temperature, "hidden_dim": self.hidden_dim,
                            "dim": self.dim, "eps": self.eps})
        return base_config

def EdgeNeXt(input_shape=(256, 256, 3), depths=[3, 3, 9, 3], dims=[24, 48, 88, 168],
             global_block=[0, 0, 0, 3], global_block_type=['None', 'None', 'None', 'SDTA'],
             drop_path_rate=1, layer_scale_init_value=1e-6, head_init_scale=1., expan_ratio=4,
             kernel_sizes=[7, 7, 7, 7], heads=[8, 8, 8, 8], use_pos_embd_xca=[False, False, False, False],
             use_pos_embd_global=False, d2_scales=[2, 3, 4, 5], epsilon=1e-6, model_name='EdgeNeXt'):
    inputs = keras.layers.Input(input_shape, batch_size=2)

    nn = conv2d_no_bias(inputs, dims[0], kernel_size=4, strides=4, padding="valid", name="stem_")
    nn = layer_norm(nn, epsilon=epsilon, name='stem_')

    drop_connect_rates = tf.linspace(0, stop=drop_path_rate, num=int(
        sum(depths)))  # drop_connect_rates_split(num_blocks, start=0.0, end=drop_connect_rate)
    cur = 0
    for i in range(4):
        for j in range(depths[i]):
            if j > depths[i] - global_block[i] - 1:
                if global_block_type[i] == 'SDTA':
                    SDTA_encoder(dim=dims[i], drop_path=drop_connect_rates[cur + j],
                                 expan_ratio=expan_ratio, scales=d2_scales[i],
                                 use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i], name='stage_'+str(i)+'_SDTA_encoder_'+str(j))(nn)
                else:
                    raise NotImplementedError
            else:
                if i != 0 and j == 0:
                    nn = layer_norm(nn, epsilon=epsilon, name='stage_' + str(i) + '_')
                    nn = conv2d_no_bias(nn, dims[i], kernel_size=2, strides=2, padding="valid",
                                        name='stage_' + str(i) + '_')

                Conv_Encoder(dim=dims[i], drop_path=drop_connect_rates[cur + j],
                             layer_scale_init_value=layer_scale_init_value,
                             expan_ratio=expan_ratio, kernel_size=kernel_sizes[i], name='stage_'+str(i)+'_Conv_Encoder_'+str(j) + '_')(nn)  # drop_connect_rates[cur + j]

    model = keras.models.Model(inputs, nn, name=model_name)
    return model

@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class Conv_Encoder(keras.layers.Layer):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4, kernel_size=7, epsilon=1e-6,
                 name=''):

        super(Conv_Encoder, self).__init__()
        self.encoder_name = name
        self.gamma = tf.Variable(layer_scale_init_value * tf.ones(dim), trainable=True,
                                 name=name + 'gamma') if layer_scale_init_value > 0 else None
        self.drop_path = drop_path
        self.dim = dim
        self.expan_ratio = expan_ratio
        self.kernel_size = kernel_size
        self.epsilon = epsilon

    def __call__(self, x, *args, **kwargs):
        inputs = x
        x = keras.layers.Conv2D(self.dim, kernel_size=self.kernel_size, padding="SAME", name=self.encoder_name +'Conv2D')(x)
        x = layer_norm(x, epsilon=self.epsilon, name=self.encoder_name)
        x = keras.layers.Dense(self.expan_ratio * self.dim)(x)
        x = activation_by_name(x, activation="gelu")
        x = keras.layers.Dense(self.dim)(x)
        if self.gamma is not None:
            x = self.gamma * x

        x = inputs + drop_block(x, drop_rate=0.)

        return x

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"gamma": self.gamma, "drop_path": self.drop_path,
                            "dim": self.dim, "expan_ratio": self.expan_ratio,
                            "kernel_size": self.kernel_size})
        return base_config

@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class SDTA_encoder(keras.layers.Layer):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4,
                 use_pos_emb=True, num_heads=8, qkv_bias=True, attn_drop=0., drop=0., scales=1, zero_gamma=False,
                 activation='gelu', use_bias=False, name='sdf'):
        super(SDTA_encoder, self).__init__()
        self.expan_ratio = expan_ratio
        self.width = max(int(math.ceil(dim / scales)), int(math.floor(dim // scales)))
        self.width_list = [self.width] * (scales - 1)
        self.width_list.append(dim - self.width * (scales - 1))
        self.dim = dim
        self.scales = scales
        if scales == 1:
            self.nums = 1
        else:
            self.nums = scales - 1
        self.pos_embd = None
        if use_pos_emb:
            self.pos_embd = PositionalEncodingFourier(dim=dim)
        self.xca = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.gamma_xca = tf.Variable(layer_scale_init_value * tf.ones(dim), trainable=True,
                                     name=name + 'gamma') if layer_scale_init_value > 0 else None
        self.gamma = tf.Variable(layer_scale_init_value * tf.ones(dim), trainable=True,
                                 name=name + 'gamma') if layer_scale_init_value > 0 else None
        self.drop_rate = drop_path
        self.drop_path = keras.layers.Dropout(drop_path)
        gamma_initializer = tf.zeros_initializer() if zero_gamma else tf.ones_initializer()
        self.norm = keras.layers.LayerNormalization(epsilon=LAYER_NORM_EPSILON, gamma_initializer=gamma_initializer,
                                                    name=name and name + "ln")
        self.norm_xca = keras.layers.LayerNormalization(epsilon=LAYER_NORM_EPSILON, gamma_initializer=gamma_initializer,
                                                        name=name and name + "norm_xca")
        self.activation = activation
        self.use_bias = use_bias

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"width": self.width, "dim": self.dim,
                            "nums": self.nums, "pos_embd": self.pos_embd,
                            "xca": self.xca, "gamma_xca": self.gamma_xca,
                            "gamma": self.gamma, "norm": self.norm,
                            "activation": self.activation, "use_bias": self.use_bias,
                            })
        return base_config

    def __call__(self, inputs, *args, **kwargs):
        x = inputs
        spx = tf.split(inputs, self.width_list, axis=-1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = keras.layers.Conv2D(self.width, kernel_size=3, padding='SAME')(sp)  # , groups=self.width
            if i == 0:
                out = sp
            else:
                out = tf.concat([out, sp], -1)
        inputs = tf.concat([out, spx[self.nums]], -1)

        # XCA
        B, H, W, C = inputs.shape
        inputs = tf.reshape(inputs, (-1, H * W, C))  # tf.transpose(), perm=[0, 2, 1])

        if self.pos_embd:
            pos_encoding = tf.reshape(self.pos_embd(B, H, W), (-1, H * W, C))
            inputs += pos_encoding

        if self.gamma_xca is not None:
            inputs = self.gamma_xca * inputs
        input_xca = self.gamma_xca * self.xca(self.norm_xca(inputs))
        inputs = inputs + drop_block(input_xca, drop_rate=self.drop_rate, name="SDTA_encoder_")
        inputs = tf.reshape(inputs, (-1, H, W, C))

        # Inverted Bottleneck
        inputs = self.norm(inputs)
        inputs = keras.layers.Conv2D(self.expan_ratio * self.dim, kernel_size=1, use_bias=self.use_bias)(inputs)
        inputs = activation_by_name(inputs, activation=self.activation)
        inputs = keras.layers.Conv2D(self.dim, kernel_size=1, use_bias=self.use_bias)(inputs)
        if self.gamma is not None:
            inputs = self.gamma * inputs

        x = x + self.drop_path(inputs)
        return x

@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class XCA(keras.layers.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., name=""):
        super(XCA, self).__init__()
        self.num_heads = num_heads
        self.temperature = tf.Variable(tf.ones(num_heads, 1, 1), trainable=True, name=name + 'gamma')

        self.qkv = keras.layers.Dense(dim * 3, use_bias=qkv_bias)
        self.attn_drop = keras.layers.Dropout(attn_drop)
        self.k_ini = initializers.GlorotUniform()
        self.b_ini = initializers.Zeros()
        self.proj = keras.layers.Dense(dim, name="out",
                                       kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.proj_drop = keras.layers.Dropout(proj_drop)

    def __call__(self, inputs, training=None, *args, **kwargs):
        input_shape = inputs.shape
        qkv = self.qkv(inputs)
        qkv = tf.reshape(qkv, (input_shape[0], input_shape[1], 3,
                               self.num_heads,
                               input_shape[2] // self.num_heads))  # [batch, hh * ww, 3, num_heads, dims_per_head]
        qkv = tf.transpose(qkv, perm=[2, 0, 3, 4, 1])  # [3, batch, num_heads, dims_per_head, hh * ww]
        query, key, value = tf.split(qkv, 3, axis=0)  # [batch, num_heads, dims_per_head, hh * ww]

        norm_query, norm_key = tf.nn.l2_normalize(tf.squeeze(query), axis=-1, epsilon=1e-6), \
                               tf.nn.l2_normalize(tf.squeeze(key), axis=-1, epsilon=1e-6)
        attn = tf.matmul(norm_query, norm_key, transpose_b=True)
        attn = tf.transpose(tf.transpose(attn, perm=[0, 2, 3, 1]) * self.temperature, perm=[0, 3, 2, 1])

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)  # [batch, num_heads, hh * ww, hh * ww]

        x = tf.matmul(attn, value)  # [batch, num_heads, hh * ww, dims_per_head]
        x = tf.reshape(x, [input_shape[0], input_shape[1], input_shape[2]])

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"num_heads": self.num_heads, "temperature": self.temperature,
                            "qkv": self.qkv, "attn_drop": self.attn_drop,
                            "proj": self.proj, "proj_drop": self.proj_drop})
        return base_config

def edgenext_xx_small(pretrained=False, **kwargs):
    # 1.33M & 260.58M @ 256 resolution
    # 71.23% Top-1 accuracy
    # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
    # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
    # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
    model = EdgeNeXt(depths=[2, 2, 6, 2], dims=[24, 48, 88, 168], expan_ratio=4,
                     global_block=[0, 1, 1, 1],
                     global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
                     use_pos_embd_xca=[False, True, False, False],
                     kernel_sizes=[3, 5, 7, 9],
                     heads=[4, 4, 4, 4],
                     d2_scales=[2, 2, 3, 4],
                     **kwargs)

    return model

def edgenext_x_small(pretrained=False, **kwargs):
    # 2.34M & 538.0M @ 256 resolution
    # 75.00% Top-1 accuracy
    # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
    # Jetson FPS=31.61 versus 28.49 for MobileViT_XS
    # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
    model = EdgeNeXt(depths=[3, 3, 9, 3], dims=[32, 64, 100, 192], expan_ratio=4,
                     global_block=[0, 1, 1, 1],
                     global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
                     use_pos_embd_xca=[False, True, False, False],
                     kernel_sizes=[3, 5, 7, 9],
                     heads=[4, 4, 4, 4],
                     d2_scales=[2, 2, 3, 4],
                     **kwargs)

    return model

def edgenext_small(pretrained=False, **kwargs):
    # 5.59M & 1260.59M @ 256 resolution
    # 79.43% Top-1 accuracy
    # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
    # Jetson FPS=20.47 versus 18.86 for MobileViT_S
    # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
    model = EdgeNeXt(depths=[3, 3, 9, 3], dims=[48, 96, 160, 304], expan_ratio=4,
                     global_block=[0, 1, 1, 1],
                     global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
                     use_pos_embd_xca=[False, True, False, False],
                     kernel_sizes=[3, 5, 7, 9],
                     d2_scales=[2, 2, 3, 4],
                     **kwargs)

    return model

if __name__ == '__main__':
    model = edgenext_small()
    model.summary()
    # from download_and_load import keras_reload_from_torch_model
    # keras_reload_from_torch_model(
    #     'D:\GitHub\EdgeNeXt\edgenext_small.pth',
    #     keras_model=model,
    #     # tail_align_dict=tail_align_dict,
    #     # full_name_align_dict=full_name_align_dict,
    #     # additional_transfer=additional_transfer,
    #     input_shape=(256, 256),
    #     do_convert=True,
    #     save_name="adaface_ir101_webface4m.h5",
    # )
leondgarse commented 2 years ago

Just 2 issues:

I also just read the article today. :)

whalefa1I commented 2 years ago

照猫画虎!looking forward to ur modify & update~~

whalefa1I commented 2 years ago

And another backbone : FYI -> https://paperswithcode.com/paper/efficientformer-vision-transformers-at

leondgarse commented 2 years ago

两个都更新上了!

whalefa1I commented 2 years ago

感动!!!