Eugen2525 opened 5 years ago
So I kind of found a solution, but it only partially "works"...
When I run the code below in Keras, the model works:
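import keras
import tensorflow as tf
from keras_contrib.layers import CRF
from tensor2tensor.models import transformer

def transformer_code(inputLayer):
    # transformer_base() uses hidden_size = 512
    hparams = transformer.transformer_base()
    encoder = transformer.TransformerEncoder(hparams, mode=tf.estimator.ModeKeys.TRAIN)
    # tensor2tensor works on 4-D tensors: [batch, length, 1, depth]
    x = keras.backend.expand_dims(inputLayer, axis=2)
    y = encoder({"inputs": x, "targets": 0, "target_space_id": 0})
    # The encoder returns (output, losses); remove the dummy axis from the output
    y = keras.backend.squeeze(y[0], 2)
    return y

def trainModel(args, trainInput, trainOutput, testInput, testOutput, taskName, tags):
    # Input shape: (time steps, features per step)
    inputLayer = keras.layers.Input(shape=(len(trainInput[0]),
                                           len(trainInput[0][0])), dtype='float32')
    # Project the features to 512 channels before the Transformer encoder
    inputAfterDense = keras.layers.Dense(512, activation='relu')(inputLayer)
    crfLayer = CRF(len(tags), sparse_target=True, name='result')
    y = keras.layers.Lambda(transformer_code)(inputAfterDense)
    modelPred = crfLayer(y)
    model = keras.Model(inputs=inputLayer, outputs=modelPred)
    model.compile(
        optimizer='adam',
        loss={'result': crfLayer.loss_function},
        metrics={'result': crfLayer.accuracy}
    )
    print 'finish model setting'
    # summary() prints the table itself and returns None, so don't print its result
    model.summary()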
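The Lambda above reshapes the 3-D Keras tensor into the 4-D layout that tensor2tensor uses internally ([batch, length, 1, depth], as far as I can tell) and squeezes it back afterwards. The shape round trip can be sketched with NumPy; the sizes here are made up for illustration and are not from the issue:

```python
import numpy as np

# Hypothetical sizes: 2 sequences, 7 time steps, 512 features per step.
batch, length, depth = 2, 7, 512
x = np.zeros((batch, length, depth), dtype=np.float32)

# Mirrors keras.backend.expand_dims(inputLayer, axis=2):
# [batch, length, depth] -> [batch, length, 1, depth]
x4d = np.expand_dims(x, axis=2)
print(x4d.shape)  # (2, 7, 1, 512)

# Mirrors keras.backend.squeeze(y[0], 2) on the encoder output:
# [batch, length, 1, depth] -> [batch, length, depth]
y3d = np.squeeze(x4d, axis=2)
print(y3d.shape)  # (2, 7, 512)
```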
But if I remove the Dense layer
    inputAfterDense = keras.layers.Dense(512, activation='relu')(inputLayer)
training breaks and the prediction accuracy stays near zero the whole time.
Why is that?
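One plausible factor, not confirmed anywhere in this thread: in a Transformer encoder every attention and feed-forward sublayer sits inside a residual connection, x + sublayer(x), and the sublayers emit hidden_size channels (512 for transformer_base). The residual sum is only well defined when the input depth already equals that width, which is exactly what a Dense(512) projection guarantees. A NumPy sketch of the shape constraint, with a hypothetical raw feature depth of 40 and random weights standing in for real sublayers:

```python
import numpy as np

rng = np.random.default_rng(0)

batch, length = 2, 5
d_in = 40      # hypothetical raw feature depth (the issue does not state it)
d_model = 512  # hidden_size of transformer_base


features = rng.standard_normal((batch, length, d_in))


def sublayer(x, w):
    """Stand-in for an attention or feed-forward sublayer that emits d_model channels."""
    return x @ w


# Residual connection directly on the raw features: depths 40 and 512 cannot be added.
try:
    features + sublayer(features, rng.standard_normal((d_in, d_model)))
    raw_residual_ok = True
except ValueError:
    raw_residual_ok = False
print(raw_residual_ok)  # False

# Project to d_model first, as Dense(512, activation='relu') does in the issue.
w_proj = rng.standard_normal((d_in, d_model))
h = np.maximum(features @ w_proj, 0.0)
out = h + sublayer(h, rng.standard_normal((d_model, d_model)))
print(out.shape)  # (2, 5, 512)
```

The projection also gives the model a learned embedding of the input features, which is the usual way raw features are fed to a Transformer.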
I solved the issue with another library, but could not fix it with this one, so the question is still open.
@Eugen2525 did you find a solution?
No, the only way I found is to use a Transformer library for Keras. I never solved this issue itself.
What did you do?
I used this custom library:
Description
I want to convert the Keras code below, which uses a bidirectional LSTM, into a Transformer:
lstmLayer = keras.layers.Bidirectional(
    keras.layers.CuDNNLSTM(args.rnnSize, return_sequences=True,
                           recurrent_initializer='glorot_uniform'))(inputLayer)
How can this be accomplished? ...
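Conceptually, the BiLSTM can be replaced by a stack of Transformer encoder layers. Self-attention without a causal mask lets every position attend to every other one, so, like a bidirectional LSTM with return_sequences=True, each output vector sees both left and right context, while a positional encoding stands in for the recurrence. A minimal single-head NumPy sketch of one such layer, with random weights and hypothetical sizes; it follows the post-norm layout of the original Transformer paper, is not tensor2tensor's implementation, and has no training code:

```python
import numpy as np


def positional_encoding(length, d_model):
    """Sinusoidal position signal from "Attention Is All You Need"."""
    pos = np.arange(length)[:, None]                    # [length, 1]
    i = np.arange(d_model // 2)[None, :]                # [1, d_model/2]
    angles = pos / np.power(10000.0, 2 * i / d_model)   # [length, d_model/2]
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def encoder_layer(x, params):
    """One encoder layer: self-attention then feed-forward, each with a residual + norm."""
    q, k, v = x @ params["wq"], x @ params["wk"], x @ params["wv"]
    d_k = q.shape[-1]
    # [batch, length, length]; no causal mask, so each step sees past and future.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    attn = softmax(scores) @ v
    x = layer_norm(x + attn @ params["wo"])
    ff = np.maximum(x @ params["w1"], 0.0) @ params["w2"]
    return layer_norm(x + ff)


rng = np.random.default_rng(0)
batch, length, d_in, d_model, d_ff = 2, 6, 40, 64, 128  # hypothetical sizes


def init(shape):
    return rng.standard_normal(shape) / np.sqrt(shape[0])


params = {name: init(shape) for name, shape in {
    "wq": (d_model, d_model), "wk": (d_model, d_model), "wv": (d_model, d_model),
    "wo": (d_model, d_model), "w1": (d_model, d_ff), "w2": (d_ff, d_model),
}.items()}

inputs = rng.standard_normal((batch, length, d_in))
w_in = init((d_in, d_model))  # input projection, the role of Dense(512) in the issue
x = inputs @ w_in + positional_encoding(length, d_model)
out = encoder_layer(x, params)
print(out.shape)  # (2, 6, 64): one vector per time step
```

In a real model several of these layers would be stacked (with multi-head attention and learned weights), and the output fed to the CRF layer where lstmLayer was used before.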
Environment information
Linux, Python 2.7
Python 2.7.15 :: Anaconda, Inc.
For bugs: reproduction and error logs