lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

Question: masking in token shifting #208

Open pfeatherstone opened 8 months ago

pfeatherstone commented 8 months ago

In token shifting, you explicitly zero out masked items:

https://github.com/lucidrains/x-transformers/blob/5ce82c9a9b4404e2405957a97ddc78abdbc60885/x_transformers/x_transformers.py#L554-L555

Is this strictly necessary? Since we are shifting right, the shifted tokens should be valid, right? Or is this accounting for items masked on the left, in which case you might be shifting in and adding an invalid token?

I noticed that RecurrentMemoryTransformer didn't do this:

https://github.com/lucidrains/recurrent-memory-transformer-pytorch/blob/d45ef72a40324c6224ffacb890d5593a69db73de/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py#L65-L70

Hence my question about whether it's strictly necessary.
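
A minimal sketch of the concern, assuming left padding and a right shift of one (hypothetical helper names, not the library's exact code): without zeroing masked positions, the padded embedding sitting just before the first real token gets shifted onto it.

```python
import torch
import torch.nn.functional as F

def shift_right(t, amount = 1):
    # shift along the sequence dimension, padding the front with zeros
    return F.pad(t, (0, 0, amount, -amount), value = 0.)

x = torch.ones(1, 4, 2)                            # pretend embeddings, (batch, seq, dim)
mask = torch.tensor([[False, False, True, True]])  # first two positions are left padding

# without zeroing, the padded (invalid) embedding at position 1 is shifted
# into position 2, the first real token
print(shift_right(x)[0, 2])                        # tensor([1., 1.])  <- leaked from padding

# zeroing masked positions first (as the linked lines do) prevents the leak
x_zeroed = x.masked_fill(~mask[..., None], 0.)
print(shift_right(x_zeroed)[0, 2])                 # tensor([0., 0.])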

lucidrains commented 8 months ago

@pfeatherstone i think i allow for bidirectional shifting, maybe that's why

i can check later
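
If bidirectional shifting is indeed the reason, a hedged sketch of why it would matter (again with a hypothetical shift helper, not the library's exact code): a negative shift pulls later positions backward, so right-side padding, the more common case, would leak into the last real tokens unless masked positions are zeroed first.

```python
import torch
import torch.nn.functional as F

def shift(t, amount):
    # positive amount shifts toward later positions, negative toward earlier ones
    return F.pad(t, (0, 0, amount, -amount), value = 0.)

x = torch.ones(1, 4, 2)
mask = torch.tensor([[True, True, True, False]])   # last position is right padding

# a leftward shift pulls the padded slot at position 3 into position 2 ...
print(shift(x, -1)[0, 2])                          # tensor([1., 1.])

# ... unless masked positions are zeroed out first
print(shift(x.masked_fill(~mask[..., None], 0.), -1)[0, 2])   # tensor([0., 0.])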