dvatterott / BMM_attentional_CNN

A CNN with an attentional module that I built while attending the brains minds and machines summer course
GNU General Public License v3.0
68 stars 33 forks source link

hey,i'm so sorry bother you but i really feel so crazy when i try to add your attention block to my model ,so can u help me fix it? #3

Closed Zhongan-Wang closed 5 years ago

Zhongan-Wang commented 5 years ago
import os
from keras.models import Model
from keras.layers.core import Dense, Dropout, Activation, Reshape, Permute, RepeatVector, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose, ZeroPadding2D
from keras.layers.pooling import AveragePooling2D, GlobalAveragePooling2D
from keras.layers import Input, Flatten
from keras.layers.merge import concatenate, multiply

from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras.layers.wrappers import TimeDistributed
from keras.layers.recurrent import GRU, LSTM
from keras.layers.wrappers import Bidirectional
from keras.engine.topology import Layer, InputSpec
from keras import initializers as initializations
import keras.backend as K
from attention_utils import get_activations, get_data_recurrent
import tensorflow as tf
from keras import backend as K
from keras import regularizers, constraints, initializers, activations
from keras.layers.recurrent import Recurrent
from keras.engine import InputSpec
from tdd import _time_distributed_dense
import numpy as np

def attention_3d_block(inputs):
    input_dim = int(inputs.shape[1])
    a = Permute((2, 1))(inputs)
    a = Dense(input_dim, activation='softmax')(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    # print("a_probs shape :   ",a_probs.shape)
    output_attention_mul = multiply([inputs, a_probs], name='attention_mul')
    return output_attention_mul

def conv_block(input, growth_rate, dropout_rate=None, weight_decay=1e-4):
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(growth_rate, (3, 3), kernel_initializer='he_normal', padding='same')(x)
    if (dropout_rate):
        x = Dropout(dropout_rate)(x)
    return x

def dense_block(x, nb_layers, nb_filter, growth_rate, droput_rate=0.2, weight_decay=1e-4):
    for i in range(nb_layers):
        cb = conv_block(x, growth_rate, droput_rate, weight_decay)
        x = concatenate([x, cb], axis=-1)
        nb_filter += growth_rate
    return x, nb_filter

def transition_block(input, nb_filter, dropout_rate=None, pooltype=1, weight_decay=1e-4):
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(nb_filter, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
               kernel_regularizer=l2(weight_decay))(x)

    if (dropout_rate):
        x = Dropout(dropout_rate)(x)

    if (pooltype == 2):
        x = AveragePooling2D((2, 2), strides=(2, 2))(x)
    elif (pooltype == 1):
        x = ZeroPadding2D(padding=(0, 1))(x)
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    elif (pooltype == 3):
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    return x, nb_filter

def global_average_pooling(x):
    return K.mean(x, axis = (2, 3))

def global_average_pooling_shape(input_shape):
    return input_shape[0:2]

def change_shape1(x):
    x = K.reshape(K.transpose(x),(420,64))
    print("change shape",x.shape)
    return x

def att_shape(input_shape):
    return (input_shape[0][0],14,138,64)

def att_shape2(input_shape):
    return input_shape[0][0:4]

def attention_control(args):
    x,dense_2 = args
    print("attention shape ",x.shape)
    find_att = K.reshape(x,(7,60,64))
    print("find_att0 shape : ",find_att.shape)
    # find_att = K.transpose(find_att[:,:,:])
    find_att = K.mean(find_att,axis=1)
    find_att = find_att/K.sum(find_att,axis=1)
    print("---find_att :",find_att.shape)
    find_att = K.repeat_elements(find_att,300,axis=1)
    print("find_att 1 shape  ",find_att.shape)
    find_att = K.reshape(find_att,(1,7,60,64)) #(?, 16, 140, 64)
    print("find_att2 shape :",find_att.shape)
    return find_att
def dense_cnn(input, nclass):
    rnnunit = 256
    units = 256
    _dropout_rate = 0.2
    _weight_decay = 1e-4

    _nb_filter = 64
    # conv 64  5*5 s=2
    x0 = Conv2D(_nb_filter, (5, 5), strides=(2, 2), kernel_initializer='he_normal', padding='same',
               use_bias=False, kernel_regularizer=l2(_weight_decay))(input)# (?, 16, 140, 64)
    print("x0 shape",x0.shape)
    # 64 +  8 * 8 = 128
    x1, _nb_filter = dense_block(x0, 8, _nb_filter, 8, None, _weight_decay)
    # 128
    x2, _nb_filter = transition_block(x1, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192
    x, _nb_filter = dense_block(x2, 8, _nb_filter, 8, None, _weight_decay)
    # 192->128
    # x=attenton_cnn(x)
    x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)

    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    # (None, 4, 35, 192)
    dense_1 = Lambda(global_average_pooling,output_shape=global_average_pooling_shape,name='dense_1')(x)  # (,32)
    dense_2 = Dense(10, activation='softmax', name='dense_2')(dense_1)  # (,10)

    con_shape1 = Lambda(change_shape1, output_shape=(64,), name='change_shape1')(x)
    print(con_shape1.shape)
    find_att = Dense(64, activation='softmax', name='att_con')(con_shape1)
    print("dense find att shape ",find_att.shape)
    find_att = Lambda(attention_control, output_shape=att_shape, name="att_con")([find_att, dense_2])
    zero_3a = ZeroPadding2D((8, 50), name='convzero_3')(find_att)
    print("==find att==",find_att.shape)
    apply_attention = multiply([x0,zero_3a])
    x = Activation('relu')(apply_attention)

    return x

def dense_blstm(input):
    pass

if __name__ == "__main__":
    input = Input(shape=(32, 280, 1), name='the_input')
    y_pred = dense_cnn(input, 15)
    basemodel = Model(inputs=input, outputs=y_pred)
    basemodel.summary()

it's a densent i want to add attention i would appreciate it if u can fix my code

dvatterott commented 5 years ago

Sorry but I am busy with other things and can't help with this.