gaussalgo / adaptor

ACL 2022: Adaptor: a library to easily adapt a language model to your own task, domain, or custom objective(s).
MIT License

`TokenClassification` objective sometimes misaligns tokens and wordpieces #25

Closed: Witiko closed this issue 2 years ago.

Witiko commented 2 years ago

The following code shows that `TokenClassification` sometimes fails to align tokens and wordpieces:

>>> text = 'Ce. Jif. Cebivskym z Záduba a Kat. z Hostouné о odkaz . . Janem z Pernsteina a Boh. Cernfnem o rukojemstvi . Lad. z Boskovic a Boh. Cerninem o rukojemstvi . . Han. Trojem kupcem a Mat, Libikem z Radovesic o réeni . Jindř., Kulem z Véfic a Zikm. Pétipeskym z Kr. Dvoru o koné . Mik. postiihaem z Budějovic a Jindř. Sudlicem z Jivovice o dluh Jiř. z Puchova a Kun. Sertyngrem z Sertynge o koné . Krist. Talkenberkem a Janem Blektou o vložení včna do desk Jif. ze Stranec a Katef. Kozlovou o základ propadeny . . , Arn. z Drasova a Linh, Nekáem z Landeku o jistinu . Mik. Kouhou a Václ. z Dédibab o dluh . 654 2087. 2088. 2089. 2090. 2091. 2092. 2093. 2094. 2095. 2096. 2097. 2098. 2099. 2100. 2101. 2102. 2103. 2104. 2105. 2106. 2107. 2108. 2109. 2110. 2111. 2112. 2113. 2114. 2115. 2116. 2117. 2118. 2119. 2120. 2121. 2122. 2123. 2124. 2125. 2126. 2127. 2128. 2129. 2130. 2131. 2132. 2133. 2134. 2135. 2136. . list, . list. . list. . list. . list, . list. . list. . dub. . dub.'
>>> labels = 'O O O O O O O O O O O O O O O O O O O O O O O O LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O'
>>>
>>> assert len(text.split()) == len(labels.split())
>>>
>>> from adaptor.lang_module import LangModule
>>> from adaptor.objectives.classification import TokenClassification
>>> 
>>> lang_module = LangModule('xlm-roberta-base')
>>> objective = TokenClassification(lang_module, batch_size=1, texts_or_path=[text], labels_or_path=[labels])
>>> list(objective._wordpiece_token_label_alignment([text], [labels]))

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/***/adaptor/objectives/classification.py", line 52, in _wordpiece_token_label_alignment
    next_token = tokens[0]
IndexError: list index out of range

After removing all tokens that consist only of non-alphanumeric characters (i.e. tokens for which `re.fullmatch(r'\W+', token, re.UNICODE)` matches), together with their labels, the problem disappears. This suggests that the standalone `.` and `,` tokens may be the source of the issue.
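A minimal sketch of that filtering step, assuming the whitespace-separated `text` and `labels` strings from the example above. The helper name `drop_punctuation_tokens` is hypothetical and not part of adaptor. It is only a diagnostic workaround: it hides the symptom and drops the labels of the removed tokens, so it is safe only when those labels are all `O`.

```python
import re
from typing import Tuple


def drop_punctuation_tokens(text: str, labels: str) -> Tuple[str, str]:
    """Hypothetical helper: remove whitespace tokens made only of non-word
    characters (e.g. a lone "." or ","), together with their labels."""
    tokens, tags = text.split(), labels.split()
    if len(tokens) != len(tags):
        raise ValueError(f"{len(tokens)} tokens but {len(tags)} labels")
    # Keep a token unless it consists entirely of \W characters.
    kept = [(tok, tag) for tok, tag in zip(tokens, tags)
            if not re.fullmatch(r"\W+", tok, re.UNICODE)]
    return " ".join(tok for tok, _ in kept), " ".join(tag for _, tag in kept)
```

Tokens such as `Mat,` or `2087.` are kept, because they contain at least one word character.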

We can guard against the `IndexError` by setting the last artificial token to `None` instead of `wordpieces[-1]` (see https://github.com/gaussalgo/adaptor/blob/d5b64f353ec0314dfbcf92a426b97bc411fbf4d3/adaptor/objectives/classification.py#L41-L43). This ensures that we never consume the last artificial token by accident if the alignment somehow gets ahead of the wordpieces. Of course, we should still investigate why the misalignment happens and fix it, so that we don't silently feed garbage to the model!
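To illustrate the sentinel idea, here is a standalone sketch of a lookahead-based alignment loop. It is not adaptor's implementation: the function name `align_labels` is hypothetical, and it assumes SentencePiece-style wordpieces (as produced for `xlm-roberta-base`) in which the first piece of each whitespace-separated word starts with `▁` (U+2581). The lookahead after the last piece is `None`, so the end of the input can never be mistaken for a real wordpiece, and any misalignment raises an explicit `ValueError` instead of a bare `IndexError` or silently mislabelled pieces.

```python
from typing import List, Optional, Tuple

# Assumption: SentencePiece marks the first piece of each word with "▁".
WORD_START = "\u2581"


def align_labels(labels: List[str], pieces: List[str]) -> List[Tuple[str, str]]:
    """Hypothetical helper (not adaptor's code): label each wordpiece with
    the label of the whitespace token it belongs to."""
    # Pad the lookahead with a None sentinel rather than pieces[-1], so the
    # final lookahead can never be confused with a real wordpiece.
    lookahead: List[Optional[str]] = [*pieces[1:], None]
    aligned: List[Tuple[str, str]] = []
    label_idx = 0  # index of the whitespace token currently being consumed
    for piece, next_piece in zip(pieces, lookahead):
        if label_idx >= len(labels):
            # Fail loudly instead of feeding mislabelled pieces to the model.
            raise ValueError(f"ran out of labels at wordpiece {piece!r}")
        aligned.append((piece, labels[label_idx]))
        # The current token ends when the next piece starts a new word
        # or when the sentinel (end of input) is reached.
        if next_piece is None or next_piece.startswith(WORD_START):
            label_idx += 1
    if label_idx != len(labels):
        raise ValueError(f"aligned {label_idx} tokens but got {len(labels)} labels")
    return aligned
```

With this scheme, a standalone `▁` piece followed by `.` is attached to a single token, e.g. `align_labels(["O", "O"], ["▁Jan", "em", "▁", "."])` labels all four pieces `O` and consumes exactly two labels.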


Witiko commented 2 years ago

@stefanik12 Sorry for the lack of the M in my MWE.