THUNLP-MT / THUMT

An open-source neural machine translation toolkit developed by Tsinghua Natural Language Processing Group
BSD 3-Clause "New" or "Revised" License

upgrade the transmission of state for decoding #44

Closed znculee closed 6 years ago

znculee commented 6 years ago

Researchers in NMT often build new network structures to test new ideas, and sometimes they need state to carry intermediate results from the encoder to the decoder. However, in the original code the whole state is not passed through to the decoding part; a new state dictionary is constructed instead. This revision keeps all existing functionality unchanged while making it easier for other researchers to modify the base code.
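
To make the difference concrete, here is a minimal sketch using plain dicts (the function names and values are only illustrative, not the actual THUMT code):

```python
# Illustrative sketch of the two styles, using plain dicts and stand-in values.

def decoding_graph_original(state):
    decoder_state = {"layer_0": "decoder tensors"}   # stand-in value
    # A brand-new dict is returned, so any extra keys added on the encoder
    # side (e.g. intermediate results a researcher wants downstream) are dropped.
    return {"encoder": state["encoder"], "decoder": decoder_state}

def decoding_graph_proposed(state):
    decoder_state = {"layer_0": "decoder tensors"}   # stand-in value
    # The incoming state is reused, so custom entries survive into decoding.
    state["decoder"] = decoder_state
    return state

encoder_state = {"encoder": "encoder output", "extra_info": "intermediate result"}
print(decoding_graph_original(dict(encoder_state)))  # "extra_info" is lost
print(decoding_graph_proposed(dict(encoder_state)))  # "extra_info" is kept
```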

Glaceon31 commented 6 years ago

Thanks for your idea. Although transmitting the intermediate results is helpful when testing new ideas, it makes state carry unused data otherwise. We recommend adding the transmission only when it is needed.

znculee commented 6 years ago

@Glaceon31 Thanks for your comments. This does introduce unused entries in state when building the graph, but they only affect the graph-building stage, not the TensorFlow session at run time, so they do not add any computation. On the other hand, it makes testing new ideas much easier. That said, if you think constructing a new state is the better design, I will close this pull request. Thank you for your contribution.

Glaceon31 commented 6 years ago

It seems that I misunderstood the usage of state in my previous comment. state is a dict which carries only two items: state["encoder"] and state["decoder"]. Your implementation in the function decoding_graph() should be identical to the original implementation, because state is fed through step_log_probs, next_state = func(flat_seqs, flat_state) (line 60 in inference.py) in each decoding step. Please point out if I missed something.
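
Roughly, the loop looks like this (a heavily simplified sketch; the real inference.py wraps this in tf.while_loop with beam search, and the helper names here are hypothetical):

```python
# Simplified sketch of how `state` is threaded through the decoding loop.

def decode(step_fn, init_state, max_steps, bos_id=1, eos_id=2):
    """step_fn(seqs, state) -> (log_probs, next_state), mirroring
    step_log_probs, next_state = func(flat_seqs, flat_state)."""
    seqs, state = [bos_id], init_state
    for _ in range(max_steps):
        # The previous state is fed in and the next state is carried along,
        # so whatever entries `state` contains are transmitted automatically.
        log_probs, state = step_fn(seqs, state)
        next_id = max(range(len(log_probs)), key=log_probs.__getitem__)
        seqs.append(next_id)
        if next_id == eos_id:
            break
    return seqs, state
```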

By the way, what does the modification in the function infer_shape_invariants() do? I am not familiar with the usage of this function.

znculee commented 6 years ago

Yes, in my implementation the state in the function decoding_graph is identical to the original implementation. In the original implementation, state is a dictionary containing state["encoder"] and state["decoder"]. In my implementation, state already contains state["encoder"], so we only need to assign a new key-value pair, namely state["decoder"], to the existing state.

As for the function infer_shape_invariants(), it first gets the shape of each tensor and changes some dimensions to None, which allows tf.while_loop to change the size of those tensors along the corresponding dimensions. In the original implementation, the first dimension (corresponding to the batch size) and the last dimension (corresponding to the hidden size) are not allowed to change. However, the first dimension of some state entries may not be the batch size, so I relaxed this restriction so that only the last dimension, namely the hidden size, stays fixed.
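
Roughly, the idea is the following (a sketch only, assuming the behavior described above; the actual implementation may differ in details):

```python
import tensorflow as tf

def infer_shape_invariants_original(tensor):
    shape = tensor.shape.as_list()
    # Keep the first (batch) and last (hidden) dimensions fixed; let every
    # dimension in between grow inside tf.while_loop.
    for i in range(1, len(shape) - 1):
        shape[i] = None
    return tf.TensorShape(shape)

def infer_shape_invariants_relaxed(tensor):
    shape = tensor.shape.as_list()
    # Only pin the last (hidden) dimension; the first dimension of some
    # custom state entries may not be the batch size, so it may vary too.
    for i in range(0, len(shape) - 1):
        shape[i] = None
    return tf.TensorShape(shape)
```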

For example, when testing a new idea I may need src_embedding from the function encoding_graph, and the most convenient way to transmit this information is through state. This can make the original implementation fail, whereas in my implementation any information can be freely added to state, as in the sketch below.
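
A hypothetical example of this use case (stand-in values, not the real encoding_graph code):

```python
def encoding_graph(features, params):
    src_embedding = {"placeholder": "embedded source tokens"}   # stand-in value
    encoder_output = {"placeholder": "encoder hidden states"}   # stand-in value
    state = {
        "encoder": encoder_output,
        # Extra entry; it only reaches the decoder if the whole state is
        # passed through to decoding_graph instead of being rebuilt there.
        "src_embedding": src_embedding,
    }
    return state
```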

Glaceon31 commented 6 years ago

Thanks for the explanation. The change to state is good, since it is better than returning a new dict. The change to infer_shape_invariants(), however, is not necessary for the published version.

znculee commented 6 years ago

@Glaceon31 Thanks for your comments.