tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

*help* transformer adding gibberish at the end of the line during decoding #740

Open sebastian-nehrdich opened 6 years ago

sebastian-nehrdich commented 6 years ago

Hi there,

I have trained a transformer that gives very good and precise results on most sentences for my problem (Sanskrit word segmentation). However, in about 10% of the sentences it strangely adds a lot of gibberish at the end of the phrase, effectively doubling its length, which makes these sentences totally useless. It seems to happen a little more frequently with longer sentences; the gibberish usually consists of fragments of words that occur in the input sentence, but presented out of order. It is puzzling because, apart from the gibberish, the quality of the output is very high, so the model seems to be learning well. I observed more gibberish at the beginning of training, and at some point after training for a couple of days the gibberish increases again. I am training with the transformer_base_single_gpu hyperparameters and a vocab_size of 5k.

stefan-it commented 6 years ago

Could you please specify the exact version of T2T? An example output would also be great.

I had problems in the decoding process with 1.5.7 (even end-of-sentence markers were missing). 1.5.6 worked, and so does the latest 1.6.0.

sebastian-nehrdich commented 6 years ago

t2t is 1.5.6, tensorflow is 1.5.0. The undesired output looks like this (the first line is how it should look, the second line is how it actually looks; everything after the end of the expected sentence is undesired):

BagavataH- jQAna-darSanam pravartate SrAvastyAm pUrvakEH samyaksambudDEH-mahA-prAtihAryam vidarSitam hitAya prANinAm-iti BagavataH- jQAna-darSanam pravartate SrAvastyAm pUrvakEH samyaksambudDEH-mahA-prAtihAryam vidarSitam hitAya prANinAm-iti sambudDEH-mahA-prAtihAryam vidarSitam hitAya prANinAm-iti vidarSitam hitAya prANinAm-iti BagavataH- jQAna-darSanam hitAya prANinAm-iti An-iti BagavataH- jQAna-darSanam hitAya prANinAm-iti vidarSitam hitAya prANinAm-iti vidarSitam hitAya prANinAm-

It affects about 2% of all sentences. There is a small chance that the system is simply not yet converged and this happens because it has not trained enough; at least the overall precision is still improving slightly, by about 0.5% per 10 hours. However, as far as I can observe, this improvement does not extend to the gibberish sentences.

I found a simple hack that works around the problem. Since in this scenario the number of tokens in the input sequence and the output sequence is expected to be identical, I just run a postprocessing script that removes the unwanted tokens. It works, but I wonder whether there is a more elegant solution.
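For illustration, here is a minimal sketch of such a postprocessing step (not the script actually used here; the function name is made up, and it simply assumes the decoded output should contain no more tokens than the input):

```python
def truncate_to_input_length(source_line, decoded_line):
    """Drop any decoded tokens beyond the number of tokens in the source.

    Assumes source and target are whitespace-tokenized and should have the
    same length, which holds for this segmentation setup.
    """
    n_tokens = len(source_line.split())
    return " ".join(decoded_line.split()[:n_tokens])


# Example: the gibberish tail of a decoded sentence is cut off.
src = "BagavataH- jQAna-darSanam pravartate"
out = "BagavataH- jQAna-darSanam pravartate jQAna-darSanam pravartate"
print(truncate_to_input_length(src, out))
# -> "BagavataH- jQAna-darSanam pravartate"
```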

martinpopel commented 6 years ago

> the number of tokens in the input sequence and the output sequence are expected to be totally identical

If this is really the case (and the whole task is just about inserting spaces), then it would be better to treat the task as sequence labeling (binary: split vs. no-split) rather than sequence-to-sequence translation. Then the output would be forced to have the same length as the input. You would use just transformer_encoder followed by a softmax (and no decoder).
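As an aside, here is a minimal illustrative sketch of that architecture in plain PyTorch (not tensor2tensor code; positional encodings and the training loop are omitted), just to show the encoder-plus-per-token-softmax shape that forces the output to have the same length as the input:

```python
# Illustrative only: a plain PyTorch sketch, not tensor2tensor's transformer_encoder.
import torch
import torch.nn as nn

class SplitTagger(nn.Module):
    """Binary sequence labeling: for each input token, predict split / no-split."""

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_classes)  # per-token softmax over 2 labels

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) -> logits: (batch, seq_len, num_classes)
        x = self.embed(token_ids)
        x = self.encoder(x)
        return self.classifier(x)

# The output sequence has exactly the same length as the input by construction,
# so no gibberish can be appended during decoding.
model = SplitTagger(vocab_size=128)
logits = model(torch.randint(0, 128, (2, 40)))
print(logits.shape)  # torch.Size([2, 40, 2])
```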

However, Sanskrit tokenization is complicated because of sandhi, so you cannot expect the same number of tokens on the output (no matter if the tokens are characters or subwords) because the possible tokenizations differ in the number of tokens, for example kurvannapnoti can be tokenized either as kurvan apnoti or as kurvan na apnoti. Even here you could transform the task to sequence labeling, but not binary (you would probably need a grammar, e.g. Sanskrit Heritage Reader to list all possible segmentations in a way that could be treated as multi-class sequence labeling).

sebastian-nehrdich commented 6 years ago

Thank you for your reply! We use a trick in our training data to tackle the token-count problem: the training set contains dashes where sandhi has been applied, so the number of tokens is the same on both sides (tathāgata becoming tathā-gata, for example). I am a little skeptical about using sequence labeling in combination with a grammar. The INRIA Sanskrit Heritage Reader is a great resource, but its grammar only covers a fraction of the language, and such an approach will always remain limited, especially if one wants to segment non-classical material such as Buddhist texts or Śaiva Tantras, where grammar and orthography can diverge widely. The current state of the art is seq2seq with a character-based RNN, but I am sure the transformer can do better. With a vocab of 5k and the hack from above it is already just 1% short of the state of the art and still converging, so it might beat it. I am optimistic that proper training with an even smaller vocabulary will push things further. Going down to character level might yield the best result, but I am worried about training time since our deadline is not too far away. I will give it a try anyway; a first attempt suggested that it will indeed take a long time to converge.

martinpopel commented 6 years ago

OK. I am no expert on Sanskrit; I just thought it was easy to list all possible segmentations (because the grammar describes a fixed number of rules for how individual words join into a sequence) and that the difficult job is to select the correct segmentation in a given context. Of course, if there are non-classical materials for which the classical list of possible segmentations would not include the correct one, then my approach is not applicable.

> The training set actually contains dashes where Sandhi has been applied,

And can you detect this reliably even for non-classical texts? And what do you do with examples such as kurvannapnoti? Do you include a double dash and allow an empty word in the output? (I know this is off-topic here, sorry, I am just curious.)

Back to the T2T topic. It would be interesting to explore why the gibberish output starts increasing after a couple of days of training. I have seen gibberish output in standard translation when the source sentence was too long (longer than max_length, i.e. the maximum sentence length used in training) and I had increased the beam search alpha in order to make the translations longer (they were too short). I am not sure whether this is relevant for your case, though. Anyway, it may be worth playing with --decode_hparams="alpha=$ALPHA,extra_length=$EXTRA_LENGTH" (the defaults are ALPHA=0.6 and EXTRA_LENGTH=100); the maximal decoded length is source_length + EXTRA_LENGTH.
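For example, a decoding call along these lines (a sketch only; flag names such as --problem vs. --problems and the paths depend on your T2T version and setup):

```bash
# Decode with a smaller extra_length so the output cannot grow far beyond the input.
t2t-decoder \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --model=transformer \
  --hparams_set=transformer_base_single_gpu \
  --output_dir=$TRAIN_DIR \
  --decode_hparams="alpha=0.6,extra_length=10" \
  --decode_from_file=input.txt \
  --decode_to_file=output.txt
```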

sebastian-nehrdich commented 6 years ago

Thank you for the hint about the decoding params, I will certainly give that a try. It might really be the case that the network has not properly converged yet and I just underestimated how much slower things go when the vocabulary size is decreased. At about 60 hours of training, the accuracy on the test set is still steadily increasing. The amount of gibberish is actually rather random: sometimes it is higher, one hour later it can be much lower, then high again. At the same time the training loss is also still jumping around to some degree, so I assume it will just take its time.

Regarding the question on sandhi: we created our training set synthetically from the only bigger POS-tagged corpus available. Our algorithm inserts '-' in cases where word fusion should happen (for example kurvannāpnoti -> kurvan-na-āpnoti) and '- ' in cases where the words remain apart but characters change (kṛṣṇo 'śvaḥ -> kṛṣṇaḥ- aśvaḥ). Since the data has been created synthetically, the correspondence between the number of tokens is always 1:1.
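To make the 1:1 correspondence concrete, here is a small hypothetical check over parallel pairs built this way (the examples are the ones from this thread; the helper function is invented for illustration):

```python
def same_token_count(fused, segmented):
    """True if the fused surface form and its dash-marked segmentation have the
    same number of whitespace tokens, as the trick above guarantees: fusion is
    marked with '-' inside a token, so it never adds new tokens."""
    return len(fused.split()) == len(segmented.split())

# Fused surface form vs. dash-marked segmentation, taken from this thread.
pairs = [
    ("kurvannāpnoti", "kurvan-na-āpnoti"),   # fusion marked with '-'
    ("kṛṣṇo 'śvaḥ", "kṛṣṇaḥ- aśvaḥ"),        # words stay apart, characters change
]
for fused, segmented in pairs:
    assert same_token_count(fused, segmented)
```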

I think your thought about using a grammar is very natural, and people have worked with that approach in the past. It works for the classical language, where resources are good, but as soon as the material is a little less standard it becomes very difficult. In addition, Sanskrit is highly under-resourced: we do not even have dictionaries that cover the entire classical language, and for grammar the INRIA resources are still limited (while Sanskrit itself is very productive), not to speak of Buddhist or tantric material. Character-based seq2seq models, however, seem well capable of capturing the general rules that matter for good sandhi segmentation (e.g. the character of the word stem, number, gender). There are also cases where the gender and further grammatical information of remote words is necessary for the right segmentation, which is why expanding the context to the whole phrase gives a large boost in performance. I expect the transformer to beat the state of the art in this field. Regarding the ability of this approach to transfer to different genres, it works surprisingly well: while our model is trained on classical Sanskrit, it achieves comparable performance on Buddhist Sanskrit, even though the vocabulary is completely different. I assume this is possible because character-based seq2seq is able to capture the underlying linguistic rules that are relevant.