lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

[Minor; noob question] Uniform distribution instead of normal #232

Open p0p4k opened 10 months ago

p0p4k commented 10 months ago

From the paper image

https://github.com/lucidrains/x-transformers/blob/90cef69e272f74756a0e5aa1ddd4523c0a23e49a/x_transformers/autoregressive_wrapper.py#L274-L280

I am still trying to understand the code:

rand = torch.randn(inp.shape, device = x.device) ---> creates a tensor of normally distributed values N(0, 1), one random score per token
rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out ---> gives the first (<bos>) token the most negative score, so it can never land in the topk and is never masked
num_mask = min(int(seq * self.mask_prob), seq - 1) ---> the number of tokens to mask is a fixed ratio mask_prob of the sequence length, capped at seq - 1
indices = rand.topk(num_mask, dim = -1).indices ---> the positions with the largest random scores are chosen to be masked (so, shouldn't this be a uniform distribution according to the paper?)
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool() ---> builds the boolean mask, which is False at the masked positions
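For what it's worth, here is a minimal standalone sketch of the same masking logic (the names `fixed_ratio_mask`, `batch`, `seq_len`, `mask_prob` are placeholders, not x-transformers internals). Because `topk` only depends on the rank order of the i.i.d. random scores, swapping `torch.randn` for `torch.rand` selects an equally uniform random subset of positions:

```python
import torch

# Minimal sketch of the masking logic above (placeholder names, not x-transformers internals).
batch, seq_len, mask_prob = 2, 10, 0.3

def fixed_ratio_mask(batch, seq_len, mask_prob, sampler):
    rand = sampler(batch, seq_len)                         # i.i.d. random score per token
    rand[:, 0] = -torch.finfo(rand.dtype).max              # lowest possible score: first token never lands in topk
    num_mask = min(int(seq_len * mask_prob), seq_len - 1)  # fixed number of tokens to mask
    indices = rand.topk(num_mask, dim = -1).indices        # positions with the largest scores get masked
    return ~torch.zeros(batch, seq_len).scatter(1, indices, 1.).bool()  # False = masked

torch.manual_seed(0)
m_normal  = fixed_ratio_mask(batch, seq_len, mask_prob, torch.randn)  # normal scores
m_uniform = fixed_ratio_mask(batch, seq_len, mask_prob, torch.rand)   # uniform scores

# Both variants mask exactly int(seq_len * mask_prob) = 3 positions per row; topk only
# cares about the ranking of the scores, and the ranking of i.i.d. draws from any
# continuous distribution is a uniformly random permutation, so the masked subset is
# a uniformly random subset of the remaining positions either way.
print((~m_normal).sum(dim = -1), (~m_uniform).sum(dim = -1))  # tensor([3, 3]) tensor([3, 3])
```

So normal vs uniform should only change which positions happen to be picked for a given seed, not how the masked positions are distributed.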

I have 2 questions: (1) Given that mask_prob is already asserted to be < 1. earlier in the code, can int(seq * self.mask_prob) ever exceed seq - 1, i.e. is the min(..., seq - 1) cap ever needed? (2) Masking with a probability should mean the number of visible tokens varies, so sometimes the model would get to see more (or fewer) than (1 - mask_prob) of the tokens, yet here the ratio is forced to be exactly mask_prob for every sequence. And does using normal vs uniform make any big difference? Thanks!
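A quick arithmetic aside on (1), as a tiny standalone check rather than code from the repo: mask_prob < 1. implies seq * mask_prob < seq, so int(seq * mask_prob) can never exceed seq - 1 and the min(..., seq - 1) cap looks like a defensive guard rather than something that changes the result.

```python
# Illustration for question (1): with mask_prob < 1., int(seq * mask_prob)
# is always <= seq - 1, so min(int(seq * mask_prob), seq - 1) never kicks in.
for seq in range(1, 4097):
    for mask_prob in (0.01, 0.15, 0.5, 0.9, 0.99):
        assert int(seq * mask_prob) <= seq - 1
print("cap never binds for these values")
```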

shuishida commented 5 days ago

I think these are good questions :) The current code makes sure that a fixed percentage (mask_prob) of tokens is masked for every sequence, whereas the original paper seems to mask tokens probabilistically, i.e. mask_prob of the tokens in expectation (so the number of masked tokens can vary from sequence to sequence).
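To make the contrast concrete, here is a sketch of the per-token (Bernoulli) variant described above; this is an assumption about what masking "in expectation" would look like, not code from the repo or the paper, and `bernoulli_mask` is a made-up name:

```python
import torch

# Hypothetical per-token (Bernoulli) variant, for contrast with the repo's
# fixed-ratio version: every position after the first is masked independently
# with probability mask_prob, so the masked count varies from sequence to sequence.
def bernoulli_mask(batch, seq_len, mask_prob):
    drop = torch.rand(batch, seq_len) < mask_prob   # True where a token would be masked
    drop[:, 0] = False                              # never mask the first (<bos>) token
    return ~drop                                    # False = masked, same convention as above

torch.manual_seed(0)
mask = bernoulli_mask(4, 100, 0.15)
print((~mask).sum(dim = -1))  # counts fluctuate around 0.15 * 99 ≈ 15, rather than being fixed
```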

My hunch is that the latter is a more general augmentation than the former, so it may be more robust, but I haven't tested this hypothesis.