tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Decoder attends to two encoders #911

Status: Open. BinQian opened this issue 6 years ago.

BinQian commented 6 years ago

Description

Hi, I am trying to modify the Transformer slightly to fit my problem, a query-based text summarisation task. I would like to use two encoders, one for the documents and one for the query, and have the decoder attend to both encoded inputs. I am not sure how to add more inputs to the model. Is there a reference I can look up?
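One way to picture "the decoder attends to both inputs" is a single decoder step that cross-attends to the document encoder's outputs and then to the query encoder's outputs, adding each result back residually (a serial combination; other combination strategies exist). This is a minimal sketch that assumes a single attention head, no learned projections, no self-attention or feed-forward sublayers, and plain Python lists instead of tensors. The names `attend` and `dual_source_decoder_step` are hypothetical; in an actual T2T model each cross-attention would presumably be its own call to `common_attention.multihead_attention`.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    # Weighted sum of the value vectors, one output dimension at a time.
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


def add(a, b):
    """Element-wise addition used for the residual connections."""
    return [x + y for x, y in zip(a, b)]


def dual_source_decoder_step(state, doc_memory, query_memory):
    """Cross-attend to the document encoder, then to the query encoder."""
    # Hypothetical simplification: encoder outputs serve as both keys and values.
    state = add(state, attend(state, doc_memory, doc_memory))
    state = add(state, attend(state, query_memory, query_memory))
    return state
```

With a one-vector memory the attention weight is 1, so `dual_source_decoder_step([0.0, 0.0], [[1.0, 2.0]], [[3.0, -1.0]])` simply adds both encoder vectors to the state. Note that the two memories may have different lengths; only the vector size must match.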

vergilus commented 6 years ago

First, you need to define your own problem in the data generators (for example, my_problem.py) and a model with two encoders (for example, dual_encoder_transformer.py). If the encoder takes two inputs, you will have to override some functions in the data generator so that it generates the "feature" dicts that are passed to the model. For more, you can check my t2t project "dual_transformer", in which I tried attending to both ASR wav inputs and text inputs to produce a single output: https://github.com/vergilus/tensor2tensor/blob/master/tensor2tensor/models/dual_transformer.py
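As a rough illustration of the "feature dicts" step, here is a plain-Python sketch of a generator that yields one dict per (document, query, summary) triple. It is not T2T's API: the toy word-level vocabulary, `build_vocab`, `encode`, `generate_feature_dicts`, and the extra feature name `"query"` are all hypothetical. The reserved ids (0 for padding, 1 for end-of-sequence) are an assumption chosen to mirror T2T's text encoders. In T2T itself the corresponding hooks would be methods on your `Problem` subclass (for example `generate_samples` and `example_reading_spec`), which you should check against the version you use.

```python
EOS_ID = 1  # assumption: mirrors T2T's reserved ids (0 = PAD, 1 = EOS)


def build_vocab(texts):
    """Toy word-level vocabulary; ids start at 2 to skip PAD and EOS."""
    vocab = {}
    for text in texts:
        for word in text.split():
            if word not in vocab:
                vocab[word] = len(vocab) + 2
    return vocab


def encode(text, vocab):
    """Map words to ids and append the end-of-sequence id."""
    return [vocab[word] for word in text.split()] + [EOS_ID]


def generate_feature_dicts(examples, vocab):
    """Yield one feature dict per (document, query, summary) triple.

    "inputs" and "targets" are the usual T2T feature names; "query" is a
    hypothetical extra feature that a dual-encoder model would read.
    """
    for document, query, summary in examples:
        yield {
            "inputs": encode(document, vocab),
            "query": encode(query, vocab),
            "targets": encode(summary, vocab),
        }
```

For example, with `examples = [("the cat sat on the mat", "where is the cat", "cat on mat")]` and `vocab = build_vocab(t for ex in examples for t in ex)`, the generator yields `"inputs"` of length 7 and `"query"` of length 5. The two inputs naturally have different lengths, which the model code has to handle separately.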

martinpopel commented 6 years ago

See also https://github.com/ufal/neuralmonkey, where this is already implemented, along with much more (e.g., hierarchical attention over multiple encoders).
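Roughly speaking, hierarchical attention first attends within each encoder separately and then attends over the resulting per-encoder context vectors, so the decoder learns how much to rely on each source. This is a simplified sketch, not Neural Monkey's implementation: it assumes single-head scaled dot-product scoring at both levels and omits the learned projections a real model would use. The function names are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(state, memory):
    """Scaled dot-product attention; memory vectors are keys and values."""
    scale = math.sqrt(len(state))
    scores = [sum(s * m for s, m in zip(state, vec)) / scale for vec in memory]
    weights = softmax(scores)
    return [sum(w * vec[i] for w, vec in zip(weights, memory)) for i in range(len(memory[0]))]


def hierarchical_attention(state, memories):
    """Attend within each encoder, then attend over the per-encoder contexts."""
    # Level 1: one context vector per encoder (encoders may differ in length).
    contexts = [attend(state, memory) for memory in memories]
    # Level 2: weight the encoders' contexts against the decoder state.
    return attend(state, contexts)
```

When the decoder state scores both contexts equally, the result is their average; when it aligns strongly with one encoder's context, that encoder dominates the output.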

BinQian commented 6 years ago

@vergilus Thanks! I have been trying your script with two text inputs. I have a problem with lines 231-241 of _fast_decode in dual_transformer.py:

```python
import tensorflow as tf
from tensor2tensor.layers import common_layers

# Add a dimension at axis 1 to each input.
wav_inputs = tf.expand_dims(wav_inputs, axis=1)
txt_inputs = tf.expand_dims(txt_inputs, axis=1)
# Pad each input to rank 5 if needed.
if len(wav_inputs.shape) < 5:
  wav_inputs = tf.expand_dims(wav_inputs, axis=4)
if len(txt_inputs.shape) < 5:
  txt_inputs = tf.expand_dims(txt_inputs, axis=4)
# Note: `s` is taken from wav_inputs only and reused for txt_inputs below.
s = common_layers.shape_list(wav_inputs)
batch_size = s[0]
# Merge axes 0 and 1, giving rank-4 tensors.
wav_inputs = tf.reshape(wav_inputs, [s[0] * s[1], s[2], s[3], s[4]])
txt_inputs = tf.reshape(txt_inputs, [s[0] * s[1], s[2], s[3], s[4]])
```

There, you reshape the two inputs to the same dimensions. I am stuck because my two text inputs end up with different dimensions, which does not make sense in my case. Do you know which part of the code defines the dimensions of the inputs? I have searched the whole script and cannot find it. Thanks!
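The length dimension comes from the data rather than from the model code. Assuming text features arrive as [batch, length, 1, 1] (an assumption about how the problem is set up), the document and the query will usually have different lengths. The shape-only sketch below does not call TensorFlow: it mirrors the quoted steps on shape tuples to show why reusing one `s` for both inputs fails. A reshape is valid only when element counts match, so each input needs its own shape list. The helpers `expand_dims`, `fast_decode_target_shape`, and `can_reshape` are hypothetical stand-ins for the tensor operations.

```python
from math import prod


def expand_dims(shape, axis):
    """Shape produced by tf.expand_dims(x, axis) for a non-negative axis."""
    return shape[:axis] + (1,) + shape[axis:]


def fast_decode_target_shape(shape):
    """Mirror the quoted reshaping steps, operating on shapes only."""
    shape = expand_dims(shape, 1)        # add a dimension at axis 1
    if len(shape) < 5:
        shape = expand_dims(shape, 4)    # pad to rank 5
    s = shape
    return (s[0] * s[1], s[2], s[3], s[4])


def can_reshape(shape, target):
    """A reshape is only valid when the total element counts agree."""
    return prod(shape) == prod(target)


doc_shape = (2, 50, 1, 1)    # hypothetical: batch 2, document length 50
query_shape = (2, 10, 1, 1)  # hypothetical: batch 2, query length 10
```

Here `fast_decode_target_shape(doc_shape)` is `(2, 50, 1, 1)`, and reshaping the query to that target is impossible (20 elements versus 100). In the TF code, the analogous change would presumably be to compute `common_layers.shape_list(txt_inputs)` separately and use that list when reshaping `txt_inputs`.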