neulab / xnmt

eXtensible Neural Machine Translation

Interaction between embedders and en-/decoders #556

Closed · armatthews closed this 5 years ago

armatthews commented 5 years ago

Right now the translator seems to assume that one word is fed into the decoder at each time step, and that each word is simply an integer that can be looked up in an embedding table: `self.decoder.add_input(dec_state, self.trg_embedder.embed(input_word))`
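For concreteness, here is a minimal, self-contained sketch of the pattern this assumes (all names here are illustrative stand-ins, not xnmt's actual API): the translator owns the embedder, so the decoder only ever sees dense vectors, never the raw word.

```python
class SimpleEmbedder:
    def embed(self, word_id: int):
        # plain table lookup: only works if the "word" is an integer index
        return [float(word_id)]  # stand-in for an embedding vector

class SimpleDecoder:
    def add_input(self, dec_state, embedding):
        return dec_state + [embedding]  # stand-in for an RNN state update

trg_embedder, decoder, dec_state = SimpleEmbedder(), SimpleDecoder(), []
for input_word in [3, 17, 2]:  # integer word ids, one per time step
    dec_state = decoder.add_input(dec_state, trg_embedder.embed(input_word))
```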

This assumption breaks for some interesting models, such as using an RNNG as a decoder, or my morphology thing. In these models the input "word" may be a more complicated object: in the former case it's a tuple of `(action, word_id)`; in the latter it's a richer object involving a word, zero or more morpheme sequences, and a character sequence.
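To illustrate, the kinds of structured "words" in question might look like the following (hypothetical containers, for illustration only); an embedding-table lookup like `embed(word_id)` has no way to handle them.

```python
from typing import List, NamedTuple, Tuple

# RNNG case: each decoder input is an (action, word_id) pair, e.g. ("GEN", 42)
RnngAction = Tuple[str, int]

# Morphology case: a word plus its decomposition(s)
class MorphWord(NamedTuple):
    word_id: int
    morphemes: List[List[int]]  # zero or more morpheme sequences
    chars: List[int]            # character sequence
```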

I propose we remedy this by moving the trg_embedder from the translator into the decoder class. This means the above code would simply become `self.decoder.add_input(dec_state, input_word)`, and the decoder would store a pointer to the `trg_embedder` and call `embed` (or whatever is appropriate) itself.
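A minimal sketch of what that refactoring would look like (again with illustrative stand-in classes, not the real xnmt interfaces): the translator hands over the raw word, and the decoder decides how to turn it into a vector.

```python
class StructuredEmbedder:
    def embed(self, input_word):
        # free to unpack tuples, morphology objects, etc. into a vector
        action, word_id = input_word
        return [float(word_id)]  # stand-in for composing the real embedding

class RefactoredDecoder:
    def __init__(self, trg_embedder):
        self.trg_embedder = trg_embedder  # decoder now owns the embedder

    def add_input(self, dec_state, input_word):
        # embedding happens inside the decoder, not in the translator
        return dec_state + [self.trg_embedder.embed(input_word)]

decoder = RefactoredDecoder(StructuredEmbedder())
dec_state = decoder.add_input([], ("GEN", 42))  # structured inputs now work
```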

We might also want to do something similar with the `src_embedder` and the encoder class, to maintain parallelism.

msperber commented 5 years ago

I think your suggestion sounds reasonable, although I don't have time to think through the implications right now. I just wanted to mention one place where we previously handled more complicated source-side inputs: the `start_sent()` event can be implemented by any class anywhere in XNMT and is passed the current batch of sentences, so classes can use it for whatever purpose they like. For example, it's used to give the lattice LSTM access to the lattice structure: https://github.com/neulab/xnmt/blob/master/xnmt/transducers/lattice.py#L59 I don't think we've done that for the target side, but it would be easy to extend.
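For reference, the mechanism looks roughly like this. The `register_xnmt_handler` and `handle_xnmt_event` decorators come from `xnmt.events`, but treat the surrounding class and fields as an illustrative sketch rather than the actual lattice code.

```python
from xnmt import events

class LatticeAwareTransducer(object):
    @events.register_xnmt_handler
    def __init__(self):
        self.cur_src = None

    @events.handle_xnmt_event
    def on_start_sent(self, src):
        # called with the current batch before encoding begins; stash the
        # structured input so transduce() can consult it later
        self.cur_src = src
```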

That being said, I'm not sure whether that would be enough to handle the use cases you mentioned, and this kind of passing data around through global events is always a bit ugly, so it might be nice to find a better solution.

armatthews commented 5 years ago

On the source side we can use `start_sent()`, as you said, though the elegance of that approach is up for debate. On the target side, however, we don't have any way of passing anything into the decoder ahead of time. There needs to be some way for the decoder to get "raw" words, instead of just embeddings.

I took a stab at moving the embedder from the translator into the decoder here, and it seems to be working. Does this seem like a reasonable change to you (/ others)?