Atze00 opened this issue 2 years ago
@Atze00 oh yes, although i have found that combining absolute positional embedding and ALiBi works well and doesn't disrupt the sequence extrapolation! i may keep it the way it is so others can corroborate that
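(For anyone wanting to try that combination: a minimal sketch, assuming the alibi_pos_bias / alibi_num_heads flags documented in the x-transformers README, with the learned absolute positional embedding simply left at its default of enabled. The exact values are illustrative only.)

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,              # learned absolute positional embedding stays on (default)
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        alibi_pos_bias = True,      # ALiBi relative bias added on top of the absolute embedding
        alibi_num_heads = 4         # apply ALiBi to only a subset of heads, per the README
    )
)

x = torch.randint(0, 20000, (1, 512))
logits = model(x)                   # (1, 512, 20000)
```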
Thanks for the reply, that's both interesting and counter-intuitive. In my case, it would cause some unwanted behavior: in particular, using the Transformer-XL with recurrence, the outputs would change if I changed the length of the input segments. E.g.
import torch

seg = torch.randint(0, 20000, (1, 512))
# two 256-token segments with memories carried over, vs. one full 512-token pass
logits1, mems1 = model_xl(seg[:, :256], return_mems = True)
logits2, mems2 = model_xl(seg[:, 256:], mems = mems1, return_mems = True)
logits_final, mems_final = model_xl(seg, return_mems = True)
# compare the second half of the full pass against the segmented pass
print(torch.equal(logits_final[:, 256:], logits2))
This script would output False.
I also wanted to suggest using pre_norm=False in the Transformer-XL; otherwise the above example would output False as well.
An option I would be ok with is changing these parameters in the Transformer-XL example in the README.
To me, these details are important, because evaluating with a different segment length would drastically change the model's accuracy. I have already changed these parameters in my own code, but I feel there are good reasons to change them so that other people don't run into the same problems. I hope I'm not bothering you with details.
@Atze00 ohhh got it, for transformer-xl it would be a big no-no, you are correct
@Atze00 what do you think about a no_abs_pos_emb flag? ugh, this is getting confusing
Isn't the flag already present? use_pos_emb should work fine.
This definition of the Transformer-XL should also work with the ALiBi positional encoding:
from x_transformers import TransformerWrapper, Decoder

model_xl = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    max_mem_len = 2048,
    use_pos_emb = False,    # no absolute positional embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True,
        pre_norm = False    # post-norm, per the discussion above
    )
)
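(As a quick sanity check, a sketch rerunning the segment-consistency comparison from earlier in the thread against this configuration; per the discussion above it should now pass. allclose is used instead of equal here to allow for floating-point differences between the two paths.)

```python
import torch

seg = torch.randint(0, 20000, (1, 512))

# two 256-token segments with memories carried over, vs. one full 512-token pass
logits1, mems1 = model_xl(seg[:, :256], return_mems = True)
logits2, mems2 = model_xl(seg[:, 256:], mems = mems1, return_mems = True)
logits_full, _ = model_xl(seg, return_mems = True)

# with use_pos_emb = False and pre_norm = False, the second half of the full pass
# should match the segmented pass (the thread reports exact equality; allclose
# is used to be safe against floating-point differences)
print(torch.allclose(logits_full[:, 256:], logits2))
```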
@Atze00 oh yea, that works 👍
Hi. To the best of my understanding, this line of code should be like this:
https://github.com/lucidrains/x-transformers/blob/a9de3a837ae69c917d8e26e71b967f750be6a1d2/x_transformers/x_transformers.py#L675