keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Additional input to the network mid-way (e.g. after conv. layers) #1330

Closed: PiranjaF closed this issue 8 years ago

PiranjaF commented 8 years ago

I'd like to extend the CNN-LSTM example to include additional information that should not be passed through the CNN, but instead added directly to the LSTM alongside the features extracted by the CNN. How would I do that?

I imagine this is relevant in many use cases, as there is often useful metadata associated with images or text. For instance, I have a case similar to the CNN-LSTM text example, where I also have the time of day, the geographical location and other information related to the text being examined. The metadata added to the LSTM has only about 10 dimensions, with some features one-hot encoded and others continuous.
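As a rough illustration of such a metadata vector, the sketch below assumes hypothetical features: the hour of day bucketed into eight one-hot bins, plus latitude and longitude scaled to roughly [-1, 1], for 10 dimensions in total. `encode_metadata` is an illustrative helper, not part of Keras.

```python
import numpy as np

def encode_metadata(hours, lats, lons):
    """Hypothetical encoder: 8 one-hot time-of-day buckets + 2 continuous features."""
    hours = np.asarray(hours, dtype=int)
    n = hours.shape[0]
    # One-hot encode the hour into eight 3-hour buckets (0-2, 3-5, ..., 21-23).
    one_hot = np.zeros((n, 8), dtype=np.float32)
    one_hot[np.arange(n), hours // 3] = 1.0
    # Scale latitude/longitude to roughly [-1, 1] so they sit on a similar
    # scale to the one-hot part.
    lats = np.asarray(lats, dtype=np.float32) / 90.0
    lons = np.asarray(lons, dtype=np.float32) / 180.0
    continuous = np.stack([lats, lons], axis=1)
    return np.concatenate([one_hot, continuous], axis=1)

meta = encode_metadata([0, 13, 23], [45.0, -33.9, 0.0], [90.0, 151.2, -180.0])
print(meta.shape)  # (3, 10)
```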

PiranjaF commented 8 years ago

I've found that this can be done with the Graph model. Thanks for such a useful library!

PiranjaF commented 8 years ago

To anyone interested: this is how it can be done using the Sequential model with a Merge layer in concat mode.


'''Train a recurrent convolutional network on the IMDB sentiment
classification task.

GPU command:
    THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python imdb_lstm.py

Get to 0.8498 test accuracy after 2 epochs. 41s/epoch on K520 GPU.
'''

from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Merge
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers.convolutional import Convolution1D, MaxPooling1D
from keras.datasets import imdb

# Embedding
max_features = 20000
maxlen = 100
embedding_size = 128

# Convolution
filter_length = 3
nb_filter = 64
pool_length = 2

# LSTM
lstm_output_size = 70

# Training
batch_size = 30
nb_epoch = 2

'''
Note:
Results are highly sensitive to batch_size.
Only 2 epochs are needed as the dataset is very small.
'''

print('Loading data...')
(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features, test_split=0.2)

print(len(X_train), 'train sequences')
print(len(X_test), 'test sequences')

print('Pad sequences (samples x time)')
X_train = sequence.pad_sequences(X_train, maxlen=maxlen)
X_test = sequence.pad_sequences(X_test, maxlen=maxlen)
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)

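# Dummy extra input with 3 features at each time step left after the
# convolution (100 - 3 + 1 = 98 steps) and max pooling (98 // 2 = 49 steps).
# The first 40 steps are zeros and the remaining 9 are ones; replace these
# arrays with real metadata of shape (n_samples, 49, 3).
n_steps = (maxlen - filter_length + 1) // pool_length
def make_dummy_extra(n_samples):
    return np.hstack((np.zeros(shape=(n_samples, 40, 3)),
                      np.ones(shape=(n_samples, n_steps - 40, 3))))
X_train_ex = make_dummy_extra(len(X_train))
X_test_ex = make_dummy_extra(len(X_test))
print(X_train_ex.shape)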
print('Build model...')

model = Sequential()

# Text branch: embedding -> dropout -> convolution -> max pooling, output (None, 49, 64)
convnet = Sequential()
convnet.add(Embedding(max_features, embedding_size, input_length=maxlen))
convnet.add(Dropout(0.25))
convnet.add(Convolution1D(nb_filter=nb_filter,
                        filter_length=filter_length,
                        border_mode='valid',
                        activation='relu',
                        subsample_length=1))
convnet.add(MaxPooling1D(pool_length=pool_length))
print(convnet.output_shape)

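# Extra branch: the linear (identity) activation only serves as the entry point
# for the extra input, output (None, 49, 3); its 49 time steps must match the text branch
extra = Sequential()
extra.add(Activation('linear', input_shape=(49, 3)))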
print(extra.output_shape)

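# Concatenate along axis 2, the feature axis (axis 0 is the batch, 1 the time steps),
# giving (None, 49, 64 + 3) = (None, 49, 67)
model.add(Merge([convnet, extra], mode='concat', concat_axis=2))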
print(model.output_shape)

model.add(LSTM(lstm_output_size))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              class_mode='binary')

print('Train...')
model.fit([X_train, X_train_ex], y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          validation_data=([X_test, X_test_ex], y_test), show_accuracy=True)
score, acc = model.evaluate([X_test, X_test_ex], y_test, batch_size=batch_size,
                            show_accuracy=True)
print('Test score:', score)
print('Test accuracy:', acc)
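The extra input in the script has shape (49, 3) because its time axis must match the convolutional branch: a 'valid' convolution of length 3 turns 100 steps into 98, and pooling by 2 leaves 49. The sketch below reproduces that arithmetic. Assuming the metadata is one vector per sample (as in the original question) rather than one per time step, it also repeats that vector across the 49 steps so it can be concatenated on the feature axis. Both helper names are hypothetical.

```python
import numpy as np

def steps_after_conv_pool(maxlen, filter_length, pool_length):
    # A 'valid' 1D convolution with stride 1 keeps maxlen - filter_length + 1 steps;
    # non-overlapping max pooling (stride == pool_length) then floors the division.
    conv_steps = maxlen - filter_length + 1
    return conv_steps // pool_length

def tile_over_time(meta, n_steps):
    # (n_samples, n_features) -> (n_samples, n_steps, n_features):
    # the same metadata vector is repeated at every time step.
    return np.repeat(meta[:, np.newaxis, :], n_steps, axis=1)

n_steps = steps_after_conv_pool(maxlen=100, filter_length=3, pool_length=2)
meta = np.random.rand(4, 3).astype(np.float32)  # placeholder: 3 features per sample
X_ex = tile_over_time(meta, n_steps)
print(n_steps, X_ex.shape)  # 49 (4, 49, 3)
```

Repeating the vector is what the concat-on-features approach requires. A different design, not used in this thread, would concatenate the metadata with the LSTM's final output instead, which avoids the repetition.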
alyato commented 8 years ago

@PiranjaF, it's so cool. But I don't understand why the value is 'concat_axis=2'. Should concat_axis be set to -1 or 1 instead? Could you explain what it means, please? Thanks.
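In the Keras version used above, concat_axis counts the batch dimension. For the (samples, time steps, features) tensors here, axis 0 is the batch, 1 is time and 2 is features, and for 3D tensors -1 selects the same axis as 2. The minimal NumPy sketch below uses stand-in arrays with the shapes of the two branches (64 convolutional filters and 3 extra features over 49 steps) to show the difference.

```python
import numpy as np

# Stand-ins for the two branch outputs: (batch, time steps, features).
conv_out = np.zeros((2, 49, 64), dtype=np.float32)  # convolutional branch
extra_out = np.ones((2, 49, 3), dtype=np.float32)   # extra-input branch

merged = np.concatenate([conv_out, extra_out], axis=2)
merged_last = np.concatenate([conv_out, extra_out], axis=-1)
print(merged.shape)                         # (2, 49, 67): features stacked per time step
print(np.array_equal(merged, merged_last))  # True: -1 is the same axis as 2 here

try:
    # axis=1 would stack along time, which needs equal feature sizes.
    np.concatenate([conv_out, extra_out], axis=1)
except ValueError as err:
    print("axis=1 fails:", err)  # 64 vs 3 features do not match
```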

kathanvyas commented 3 years ago

Has this been updated for the latest version of Keras?
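Current Keras has neither the Graph model nor the Merge layer. The functional API with a Concatenate layer covers the same case. Below is a sketch of the script's architecture in that style, assuming a recent Keras (2.x imported as `keras`, or 3.x) and keeping the hyperparameters from the script above. Data loading and training are left out. `Concatenate(axis=-1)` plays the role of `Merge(mode='concat', concat_axis=2)`.

```python
import numpy as np
import keras
from keras import layers

# Hyperparameters copied from the script above.
max_features = 20000
maxlen = 100
embedding_size = 128
filter_length = 3
nb_filter = 64
pool_length = 2
lstm_output_size = 70
n_extra_features = 3  # size of the extra per-time-step input in the script

# Text branch: token ids -> embedding -> dropout -> convolution -> max pooling.
text_in = keras.Input(shape=(maxlen,), dtype="int32", name="text")
x = layers.Embedding(max_features, embedding_size)(text_in)
x = layers.Dropout(0.25)(x)
x = layers.Conv1D(nb_filter, filter_length, padding="valid", activation="relu")(x)
x = layers.MaxPooling1D(pool_size=pool_length)(x)  # (None, 49, 64)

# Extra branch: a second model input, fed straight into the merge.
n_steps = (maxlen - filter_length + 1) // pool_length  # 49
extra_in = keras.Input(shape=(n_steps, n_extra_features), name="extra")

# Concatenate on the last (feature) axis: (None, 49, 67).
merged = layers.Concatenate(axis=-1)([x, extra_in])
h = layers.LSTM(lstm_output_size)(merged)
out = layers.Dense(1, activation="sigmoid")(h)

model = keras.Model(inputs=[text_in, extra_in], outputs=out)
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

# Smoke test on random data; real training would call
# model.fit([X_train, X_train_ex], y_train, ...).
rng = np.random.default_rng(0)
tokens = rng.integers(0, max_features, size=(2, maxlen)).astype("int32")
extra = rng.random((2, n_steps, n_extra_features)).astype("float32")
preds = model.predict([tokens, extra], verbose=0)
print(preds.shape)  # (2, 1)
```

Because `extra_in` is a model input in its own right, the identity `Activation('linear')` layer that the Sequential-based version used to turn the extra data into a mergeable branch is no longer needed.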