huggingface / datatrove

Freeing data processing from scripting madness by providing a set of platform-agnostic customizable pipeline processing blocks.
Apache License 2.0

Minhash mersenne hashing overflow issues. #198

Open Apsod opened 6 months ago

Apsod commented 6 months ago

Since this repo inherits a lot of code from https://github.com/ekzhu/datasketch, it should be noted that the implementation of Mersenne prime hashing used in both repos overflows uint64, and potentially causes more hash collisions than intended:

import numpy as np

SEED = 0x5eed
_mersenne_prime_py = (1 << 61) - 1
_mersenne_prime = np.uint64(_mersenne_prime_py)
num_hashes=2
N=3

gen = np.random.RandomState(SEED)
a, b = (
    gen.randint(1, _mersenne_prime, dtype=np.uint64, size=(1, num_hashes)),
    gen.randint(0, _mersenne_prime, dtype=np.uint64, size=(1, num_hashes)),
)

shingles = gen.randint(0, (1<<32), dtype=np.uint64, size=(N, 1))

def h1(a, b, shingles):
    # Numpy uint64, overflows
    return (shingles * a + b) % _mersenne_prime

def h2(a, b, shingles):
    # Native python, no overflow.
    rows = []
    for sj in shingles[:, 0].tolist():
        rows.append([
            (sj * ai + bi) % _mersenne_prime_py
            for ai, bi in zip(a[0].tolist(), b[0].tolist())
        ])

    # Values are below 2**61, so uint64 holds them and matches h1's dtype.
    return np.array(rows, dtype=np.uint64)

print(h1(a, b, shingles) == h2(a, b, shingles))
# False False False ...

guipenedo commented 6 months ago

While the resulting numbers are not the same, this does not seem to me to be conclusive proof that you get more collisions. You will simply get them on different values, since in the end the modulo squashes everything into the same value range in both cases. The pure Python implementation also has the (very big) disadvantage of being considerably slower.
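The collision question can be probed at a small scale. The sketch below is a toy analogue with assumed parameters (8-bit wrap-around, the Mersenne prime 2**5 - 1, arbitrary coefficients), not the library's code. It counts how many distinct hash values the exact affine map and the wrapped map produce over the inputs 0..p-1. It shows only that wrapping can break the one-to-one property of the exact map; it says nothing quantitative about the uint64 / 2**61 - 1 case.

```python
# Toy analogue (assumed parameters, not taken from datatrove or datasketch).
P = (1 << 5) - 1   # small Mersenne prime, 31
WORD = 1 << 8      # pretend integers wrap around at 8 bits
A, B = 29, 5       # arbitrary coefficients with 1 <= A < P


def exact(x):
    # True affine map modulo the prime: a bijection on 0..P-1 when A != 0.
    return (A * x + B) % P


def wrapped(x):
    # Same map, but the intermediate result wraps at WORD before the modulo.
    return ((A * x + B) % WORD) % P


inputs = range(P)
n_exact = len({exact(x) for x in inputs})
n_wrapped = len({wrapped(x) for x in inputs})
print(n_exact, n_wrapped)
```

With these coefficients the exact map hits every residue once, while the wrapped map sends both 0 and 23 to the same value, so it produces fewer distinct outputs.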

Apsod commented 6 months ago

I agree that this is not conclusive proof of more collisions. However, it seems like a bug to me to purport to do affine transforms modulo a Mersenne prime when this is not what the code is doing. Currently, the implementation is doing the following:

def h3(a, b, shingles):
    # Native Python, simulating uint64 overflow; equal to h1.
    rows = []
    for sj in shingles[:, 0].tolist():
        rows.append([
            # Use the Python-int prime so the arithmetic stays in exact Python ints.
            ((sj * ai + bi) % (1<<64)) % _mersenne_prime_py
            for ai, bi in zip(a[0].tolist(), b[0].tolist())
        ])

    return np.array(rows)

At that point, I doubt the Mersenne prime field serves any purpose; you could just go mod (1<<64), i.e. use no explicit mod (or Mersenne prime) at all.
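If the goal is to keep the true affine map modulo 2**61 - 1 without falling back to a Python loop, one option is to split the multiplication into 32-bit limbs so that no intermediate exceeds 64 bits. The following is a minimal sketch under stated assumptions: shingles are below 2**32 (as in the snippet above), `a` and `b` are below 2**61, and the function name `affine_mod_mersenne61` is hypothetical, not part of datatrove or datasketch.

```python
import numpy as np

P_PY = (1 << 61) - 1
P = np.uint64(P_PY)
MASK29 = np.uint64((1 << 29) - 1)
MASK32 = np.uint64((1 << 32) - 1)
S29, S32, S61 = np.uint64(29), np.uint64(32), np.uint64(61)


def affine_mod_mersenne61(a, b, x):
    """(a * x + b) mod (2**61 - 1) on uint64 arrays.

    Assumes x < 2**32 and a, b < 2**61 (hypothetical helper, see lead-in).
    """
    a_lo, a_hi = a & MASK32, a >> S32        # a = a_hi * 2**32 + a_lo
    t = x * a_hi                             # < 2**61, no overflow
    u = x * a_lo                             # < 2**64, no overflow
    # t * 2**32 mod p: write t = t_hi * 2**29 + t_lo; since 2**61 is
    # congruent to 1 mod p, this reduces to t_hi + t_lo * 2**32.
    hi = (t >> S29) + ((t & MASK29) << S32)  # < 2**61 + 2**32
    # u mod p via the usual Mersenne fold: low 61 bits plus the top 3 bits.
    lo = (u & P) + (u >> S61)                # < 2**61 + 8
    return (hi + lo + b) % P                 # sum < 2**63, no overflow


rng = np.random.RandomState(0x5eed)
a = rng.randint(1, P_PY, dtype=np.uint64, size=(1, 4))
b = rng.randint(0, P_PY, dtype=np.uint64, size=(1, 4))
x = rng.randint(0, 1 << 32, dtype=np.uint64, size=(5, 1))

# Reference values computed with arbitrary-precision Python ints.
reference = [
    [(xi * ai + bi) % P_PY for ai, bi in zip(a[0].tolist(), b[0].tolist())]
    for xi in x[:, 0].tolist()
]
print(affine_mod_mersenne61(a, b, x).tolist() == reference)
```

The bounds in the comments are what make this safe; if shingles can be wider than 32 bits, the split would need another limb.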