In transformer-xl/pytorch/mem_transformer.py,
the argument order in the call to _update_mems (a method of class MemTransformerLM) does not match its definition.
Compare the two lines:
Line 619 (definition):
def _update_mems(self, hids, mems, qlen, mlen):
Line 733 (call):
new_mems = self._update_mems(hids, mems, mlen, qlen)
I only tested with qlen == mlen, where swapping the two values has no effect, so the code ran without any problem.
The order should still be corrected so that the call and the definition agree, for the case when qlen != mlen.
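To illustrate why the mismatch stays hidden when qlen == mlen, here is a minimal sketch. The function update_mems_stub is a hypothetical stand-in that I wrote for this report; it is not the real method and only reports which value ends up in which parameter. The values 4 and 6 are arbitrary.

```python
def update_mems_stub(hids, mems, qlen, mlen):
    # Hypothetical stand-in for _update_mems: it does no memory update,
    # it only shows which value was bound to which parameter.
    return {"qlen": qlen, "mlen": mlen}


qlen, mlen = 4, 6

# Positional order used at the call site (mlen before qlen).
swapped = update_mems_stub(None, None, mlen, qlen)

# Positional order that matches the definition.
matched = update_mems_stub(None, None, qlen, mlen)

# Keyword arguments make the binding independent of position.
by_keyword = update_mems_stub(None, None, mlen=mlen, qlen=qlen)

print(swapped)     # {'qlen': 6, 'mlen': 4}
print(matched)     # {'qlen': 4, 'mlen': 6}
print(by_keyword)  # {'qlen': 4, 'mlen': 6}

# With equal lengths the two positional orders are indistinguishable.
same_len = 5
print(update_mems_stub(None, None, same_len, same_len))  # {'qlen': 5, 'mlen': 5}
```

Passing qlen and mlen by keyword at the call site would be one way to keep this from regressing, though simply reordering the positional arguments is enough.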
Thanks!