lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License

Why multiply seq_len (4096) * buckets + ticker % 4096 before sorting? #91

Closed muiPomeranian closed 4 years ago

muiPomeranian commented 4 years ago

First of all, I really appreciate your great work!

But why do we multiply the buckets by 4096? Here buckets has shape [16, 4x4096] and ticker % 4096 has shape [16, 4x4096].

I can only guess that ticker % 4096 represents the position of each token, from 0 to 4095, within each of the n hash rounds.

But why do we sort by 4096 * buckets + ticker % 4096?

I thought we would sort by the buckets alone (without even multiplying by 4096).

Also, why do we add torch.range(0,3)*64 to each row of buckets? It seems we want to differentiate the rows, or avoid duplicate numbers across rows or batches after bucketing, but I wonder why this is necessary.

Thanks a lot!!

lucidrains commented 4 years ago

Hi @seungbochoi !

Most of this code was faithfully reproduced from the official Trax implementation, so it was a learning experience for me as well. I think it is best to step through an example, say with 2 buckets:

# abcdefgh - token sequence
# 11212211 - buckets from LSH hashing

If you were to sort by the buckets alone, the result would not be deterministic, because there are identical bucket values.

# a  b  c  d  e  f  g  h  - token sequence
# 1  1  2  1  2  2  1  1  - buckets from LSH hashing
# 08 08 16 08 16 16 08 08 - buckets * seqlen (8) - still not unique
# 08 09 18 11 20 21 14 15 - buckets * seqlen + ticker % seqlen

Now you have all unique values for deterministic sorting!
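
The arithmetic in the example above can be reproduced in a few lines. This is a minimal sketch in plain Python rather than PyTorch, so it runs without dependencies; the variable names are illustrative and are not taken from the repository.

```python
# Toy version of the example above: 8 tokens, 2 possible bucket ids.
tokens = list("abcdefgh")
buckets = [1, 1, 2, 1, 2, 2, 1, 1]
seqlen = len(tokens)

# The ticker is simply each token's position in the sequence.
ticker = list(range(seqlen))

# Combine bucket id and position into a single key per token.
keys = [b * seqlen + (t % seqlen) for b, t in zip(buckets, ticker)]
print(keys)  # [8, 9, 18, 11, 20, 21, 14, 15]

# Every key is distinct, so sorting has exactly one possible result.
order = sorted(range(seqlen), key=lambda i: keys[i])
sorted_tokens = "".join(tokens[i] for i in order)
print(sorted_tokens)  # abdghcef

# Both ingredients can be recovered from a key.
recovered_buckets = [k // seqlen for k in keys]
recovered_positions = [k % seqlen for k in keys]
```

Because the position is always smaller than seqlen, the bucket id dominates the ordering and the position only breaks ties, which is why tokens within a bucket stay in their original order.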

muiPomeranian commented 4 years ago

Ah, I see! So the whole purpose is to obtain a deterministic sort.

1st Q: Is it also meant to preserve the local sequence order? E.g. in your example, bucket 1 holds a, b, d, g, h; when we sort, do we want to keep them in that order (a -> b -> (c) -> d ... etc.)?

If so, does it affect model performance? Why do we need to keep this order?

2nd Q: If you have some time, I hope you could answer my second question too (hugely appreciated!). Why do we add torch.range(0,3)*64 to each row of buckets? It seems we want to differentiate the rows, or avoid duplicate numbers across rows or batches after bucketing, but I wonder why this is necessary.

E.g.

buckets.shape = [16,4,4096], offsets = [[[0,64,128,192]]]
buckets = torch.reshape(buckets+offsets, (batch_size, -1,))

If we only consider the first batch entry (of the multi-head batch), there are 4 hash rounds, so there should be 16 bucket matrices like the one below:

[ [1, 5, 3, 3, ......, 2, 53],
  [7, 4, 3, 2, ..., 64, 53, 2],
  [2, 33, 5, 3, ......, 2, 53],
  [2, 35, 1, 64, 53, ......, 2] ]

But we compute bucket + offset, which gives:

[ [1, 5, 3, 3, ......, 2, 53],
  [7+64, 4+64, 3+64, 2+64, ..., 64+64, 53+64, 2+64],
  [2+128, 33+128, 5+128, 3+128, ......, 2+128, 53+128],
  [2+192, 35+192, 1+192, 64+192, 53+192, ......, 2+192] ]

I can see that we want to preserve the original bucket values while making every row distinct so that no numbers repeat, but I wonder why this step is necessary.

Thanks a lot !

lucidrains commented 4 years ago

Adding the offset simply gives the buckets in each hashing round their own unique range of bucket values, so that the tokens can be deterministically sorted.
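
To see why the offset matters once several hash rounds are flattened into one row, here is a minimal sketch in plain Python. The sizes (2 rounds, 2 buckets, 4 tokens) and the bucket ids are made up for illustration, and `sort_keys` is a hypothetical helper, not a function from the repository.

```python
# Hypothetical toy sizes, far smaller than the 4 rounds x 64 buckets x 4096 tokens in the question.
n_hashes, n_buckets, seqlen = 2, 2, 4

# Made-up bucket ids: one row per hash round, one column per token position.
buckets = [
    [0, 1, 1, 0],  # round 0
    [1, 1, 0, 0],  # round 1
]

def sort_keys(rows):
    """Flatten the rounds and build bucket * seqlen + (ticker % seqlen)."""
    flat = [b for row in rows for b in row]
    return [b * seqlen + (t % seqlen) for t, b in enumerate(flat)]

# Without offsets, a token that lands in the same bucket in both rounds
# produces the same key twice (positions 1 and 3 here).
keys_plain = sort_keys(buckets)
print(keys_plain)   # [0, 5, 6, 3, 4, 5, 2, 3]

# With offsets, round r uses bucket ids in [r * n_buckets, (r + 1) * n_buckets).
offsets = [r * n_buckets for r in range(n_hashes)]  # [0, 2]
shifted = [[b + off for b in row] for row, off in zip(buckets, offsets)]
keys_offset = sort_keys(shifted)
print(keys_offset)  # [0, 5, 6, 3, 12, 13, 10, 11]
```

With the offsets applied, the keys of different rounds fall into disjoint ranges, so each round's tokens are also kept together after sorting.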

The deterministic sort is necessary in a later step where the attention logits are unsorted. Since the unsorting step is not differentiable, you need the buckets_and_t to sort the gradients appropriately. https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py#L407
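
The unsorting step relies on the sort being a well-defined permutation that can be inverted. The following is a minimal sketch of that idea in plain Python, reusing the keys from the 8-token example; it illustrates the principle only and is not the repository's code, and the letters stand in for any per-token quantity such as attention outputs or gradients.

```python
keys = [8, 9, 18, 11, 20, 21, 14, 15]  # unique sort keys from the 8-token example
values = list("abcdefgh")              # stand-in for any per-token data

# Sorting permutation: order[sorted_pos] = original_pos.
order = sorted(range(len(keys)), key=lambda i: keys[i])
sorted_values = [values[i] for i in order]

# Inverse permutation: undo[original_pos] = sorted_pos.
undo = [0] * len(order)
for sorted_pos, original_pos in enumerate(order):
    undo[original_pos] = sorted_pos

# Applying the inverse permutation restores the original order.
restored = [sorted_values[j] for j in undo]
print("".join(sorted_values))  # abdghcef
print("".join(restored))       # abcdefgh
```

If two keys were equal, the sort could place the tied tokens in either order, and a permutation computed in one pass might not match the one used in another.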