aymericdamien / TensorFlow-Examples

TensorFlow Tutorial and Examples for Beginners (support TF v1 & v2)

How can I merge features extracted from two different models and feed them to another model in TensorFlow? #138

Open ghost opened 7 years ago

ghost commented 7 years ago

I am new to TensorFlow. I am trying to implement the architecture shown below.

[image: uybmy]

In Keras, it's very easy to do this. Below is an example:

first_model = Sequential()
first_model.add(LSTM(output_dim, input_shape=(m, input_dim)))

second_model = Sequential()
second_model.add(LSTM(output_dim, input_shape=(n, input_dim)))

model = Sequential()
model.add(Merge([first_model, second_model], mode='concat'))
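Note that `Merge` is a Keras 1.x layer that later Keras releases dropped. A minimal sketch of the same concatenation with the functional API of `tf.keras` is given here, assuming TensorFlow 2.x; the sizes `m`, `n`, `input_dim` and `output_dim` and the final `Dense` layer are placeholders chosen for illustration, not values from this thread.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
m, n = 5, 7        # time steps of the two input sequences
input_dim = 3      # features per time step
output_dim = 64    # LSTM units per branch

# Two independent many-to-one LSTM branches.
input_1 = tf.keras.Input(shape=(m, input_dim))
input_2 = tf.keras.Input(shape=(n, input_dim))
features_1 = tf.keras.layers.LSTM(output_dim)(input_1)  # (batch, output_dim)
features_2 = tf.keras.layers.LSTM(output_dim)(input_2)  # (batch, output_dim)

# Merge the two feature vectors along the feature axis.
merged = tf.keras.layers.Concatenate(axis=-1)([features_1, features_2])

# Stand-in for the model that consumes the merged features.
prediction = tf.keras.layers.Dense(1)(merged)
model = tf.keras.Model(inputs=[input_1, input_2], outputs=prediction)

# Run a forward pass on random data to check the shapes.
batch = 4
x1 = np.random.rand(batch, m, input_dim).astype("float32")
x2 = np.random.rand(batch, n, input_dim).astype("float32")
y = model.predict([x1, x2], verbose=0)
print(y.shape)
```

The concatenated tensor has `2 * output_dim` features per sample, and any further layers can be stacked on it in place of the `Dense` head.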

I have written the code for two different LSTM models. How can I merge these two models in TensorFlow?

# ------- many-to-one RNN 1 -------

lstm_size_1 = 64
number_of_layers_1 = 4

stacked_lstm_1 = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(lstm_size_1, forget_bias=1.0, state_is_tuple=False)
                for _ in range(number_of_layers_1)], state_is_tuple=False)

outputs_1, state_1 = tf.nn.dynamic_rnn(stacked_lstm_1, model_input_1,
                                   sequence_length=length_sample_1,
                                   dtype=tf.float32)

# ------- many-to-one RNN 2 -------

lstm_size_2 = 64
number_of_layers_2 = 4

stacked_lstm_2 = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(lstm_size_2, forget_bias=1.0, state_is_tuple=False)
                for _ in range(number_of_layers_2)], state_is_tuple=False)

outputs_2, state_2 = tf.nn.dynamic_rnn(stacked_lstm_2, model_input_2,
                                   sequence_length=length_sample_2,
                                   dtype=tf.float32)

msampathkumar commented 7 years ago

One question: wouldn't your model lose valuable information, such as the connection between the x variables and the y variables? (From the image, I am guessing your goal is to predict a missing term using its neighbouring words or characters.)

If my guess is right, just use all these variables as one input and use a single model.
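One possible reading of this suggestion, sketched under the assumption that both sequences share the same `input_dim`, is to join them along the time axis and train a single LSTM on the result. The sizes and the `Dense` head below are hypothetical, and `tf.keras` (TensorFlow 2.x) is used for brevity rather than the `tf.contrib` code above.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes, for illustration only.
m, n = 5, 7        # time steps of the two original sequences
input_dim = 3      # features per time step (assumed equal for both)
lstm_size = 64

# Random stand-ins for the two original inputs.
batch = 4
x_part = np.random.rand(batch, m, input_dim).astype("float32")
y_part = np.random.rand(batch, n, input_dim).astype("float32")

# Join the two sequences along the time axis: (batch, m + n, input_dim).
sequence = np.concatenate([x_part, y_part], axis=1)

# One model over the combined sequence.
single_model = tf.keras.Sequential([
    tf.keras.Input(shape=(m + n, input_dim)),
    tf.keras.layers.LSTM(lstm_size),   # many-to-one: last output only
    tf.keras.layers.Dense(1),          # placeholder prediction head
])

out = single_model.predict(sequence, verbose=0)
print(sequence.shape, out.shape)
```

Whether this works better than two separate branches depends on the data; it simply lets one LSTM see both parts of the context in a single pass.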