ekzhu / datasketch

MinHash, LSH, LSH Forest, Weighted MinHash, HyperLogLog, HyperLogLog++, LSH Ensemble and HNSW
https://ekzhu.github.io/datasketch
MIT License

Poor default args in MinHashLSH? #200

Open: kuk opened this issue 1 year ago

kuk commented 1 year ago

I prepared 10 synthetic examples:

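import random

values = []
queries = []

# Each value is a run of 100 consecutive integers (1-100, 101-200, ...),
# so different values never share elements.
count = 1
for _ in range(10):

    value = []
    for _ in range(100):
        value.append(count)
        count += 1

    # The query is a copy of the value with 3 randomly chosen elements negated,
    # i.e. replaced by elements that appear in no value.
    query = value[:]
    indexes = list(range(100))
    random.shuffle(indexes)
    for index in indexes[:3]:
        query[index] = -query[index]

    values.append(value)
    queries.append(query)

# The default MinHash hash function works on bytes, so encode every element.
for items in [values, queries]:
    for item in items:
        for index in range(len(item)):
            item[index] = str(item[index]).encode('ascii')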

The value and query in each pair have Jaccard similarity > 0.9. For the first pair:

index = 0
query = queries[index]
value = values[index]

print(query)
print(value)
print(len(set(query) & set(value)) / len(set(query) | set(value)))

[b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'10', b'11', b'12', b'13', b'14', b'15', b'16', b'17', b'18', b'19', b'20', b'21', b'22', b'23', b'24', b'25', b'26', b'27', b'28', b'29', b'30', b'31', b'32', b'33', b'34', b'35', b'36', b'37', b'38', b'39', b'40', b'41', b'42', b'43', b'44', b'45', b'46', b'-47', b'48', b'49', b'50', b'51', b'52', b'53', b'54', b'55', b'-56', b'57', b'58', b'59', b'60', b'61', b'62', b'63', b'64', b'65', b'66', b'67', b'68', b'69', b'70', b'71', b'72', b'73', b'74', b'75', b'76', b'77', b'78', b'-79', b'80', b'81', b'82', b'83', b'84', b'85', b'86', b'87', b'88', b'89', b'90', b'91', b'92', b'93', b'94', b'95', b'96', b'97', b'98', b'99', b'100']
[b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'10', b'11', b'12', b'13', b'14', b'15', b'16', b'17', b'18', b'19', b'20', b'21', b'22', b'23', b'24', b'25', b'26', b'27', b'28', b'29', b'30', b'31', b'32', b'33', b'34', b'35', b'36', b'37', b'38', b'39', b'40', b'41', b'42', b'43', b'44', b'45', b'46', b'47', b'48', b'49', b'50', b'51', b'52', b'53', b'54', b'55', b'56', b'57', b'58', b'59', b'60', b'61', b'62', b'63', b'64', b'65', b'66', b'67', b'68', b'69', b'70', b'71', b'72', b'73', b'74', b'75', b'76', b'77', b'78', b'79', b'80', b'81', b'82', b'83', b'84', b'85', b'86', b'87', b'88', b'89', b'90', b'91', b'92', b'93', b'94', b'95', b'96', b'97', b'98', b'99', b'100']
0.941747572815534
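Every pair has the same similarity. A query keeps 97 of its value's 100 tokens and replaces the other 3 with negated tokens that occur in no value, so the intersection has 97 elements and the union has 100 + 3 = 103. Below is a minimal exact check with `fractions.Fraction` that rebuilds the pair printed above. The helper `jaccard` is a hypothetical name used only here:

```python
from fractions import Fraction


def jaccard(a, b) -> Fraction:
    """Exact Jaccard similarity |A & B| / |A | B| of two iterables treated as sets."""
    a, b = set(a), set(b)
    return Fraction(len(a & b), len(a | b))


# Rebuild the first pair: tokens b'1'..b'100', with 3 of them negated.
value = [str(i).encode("ascii") for i in range(1, 101)]
query = value[:]
for i in (46, 55, 78):  # positions of 47, 56 and 79, as in the printed example
    query[i] = b"-" + query[i]

print(jaccard(query, value), float(jaccard(query, value)))  # 97/103 0.941747572815534
```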

I inserted all values into MinHashLSH using the default settings. For every query I expect exactly one matching value, but in 4 of 10 cases I get no results:

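from datasketch import (
    MinHash,
    MinHashLSH
)

# threshold=0.9 is also the default; num_perm (128) and weights ((0.5, 0.5)) stay at their defaults.
lsh = MinHashLSH(
    threshold=0.9,
)
for index, value in enumerate(values):
    minhash = MinHash(num_perm=128)
    minhash.update_batch(value)
    lsh.insert(index, minhash)

for query in queries:
    minhash = MinHash(num_perm=128)
    minhash.update_batch(query)
    # Candidates returned by the LSH index
    print(lsh.query(minhash))

    # Brute-force ground truth: the value whose exact Jaccard similarity exceeds 0.9
    for index, value in enumerate(values):
        if len(set(value) & set(query)) / len(set(value) | set(query)) > 0.9:
            print(index)
            break

    print()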

[]
0

[1]
1

[2]
2

[]
3

[]
4

[5]
5

[6]
6

[]
7

[8]
8

[9]
9
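The following is an idealized back-of-the-envelope check of why some pairs are missed. MinHash LSH only returns a stored set when the query's signature agrees with it on every row of at least one band. With the 128 hash values split into b bands of r rows, a pair with Jaccard similarity s becomes a candidate with probability 1 - (1 - s^r)^b. A pair just above the threshold can therefore still be missed with noticeable probability. The sketch evaluates this for s = 97/103 under a few assumptions:

- The (b, r) splits are illustrative values chosen for this example, not necessarily the ones MinHashLSH derives from its defaults.
- The 10 queries are treated as independent trials.
- The helper names are hypothetical.

```python
from math import comb


def collision_probability(s: float, b: int, r: int) -> float:
    """Chance that two sets with Jaccard similarity s agree on all r rows
    of at least one of b bands (idealized MinHash with independent hashes)."""
    return 1.0 - (1.0 - s ** r) ** b


def prob_at_least_k_misses(miss: float, n: int, k: int) -> float:
    """Binomial tail: probability of k or more misses out of n independent queries."""
    return sum(comb(n, i) * miss ** i * (1.0 - miss) ** (n - i) for i in range(k, n + 1))


s = 97 / 103  # similarity of every value/query pair in this issue
# Hypothetical band/row splits of a 128-value signature, for illustration only.
for b, r in [(4, 32), (5, 25), (8, 16), (16, 8)]:
    p = collision_probability(s, b, r)
    print(f"b={b:2d} r={r:2d}  P(candidate)={p:.3f}  "
          f"P(>=4 of 10 missed)={prob_at_least_k_misses(1 - p, 10, 4):.3f}")
```

Fewer, longer bands (large r) make the curve steeper and push it to the right. Pairs at 0.94 then sit on the slope of the curve rather than its plateau.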

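When I change the weights, which are (false_positive_weight, false_negative_weight), to favor avoiding false negatives, I get the correct results: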

lsh = MinHashLSH(
    threshold=0.9,
    weights=(0.1, 0.9)
)

...

[0]
0

[1]
1

[2]
2

[3]
3

[4]
4

[5]
5

[6]
6

[7]
7

[8]
8

[9]
9

Is this expected behavior? Should the default threshold or weights be changed?

ekzhu commented 1 year ago

I think you brought up a great point. There needs to be some work on improving the hyperparameter optimization code, which takes the weights and chooses b (the number of bands) and r (the number of rows per band) for the index; see https://github.com/ekzhu/datasketch/blob/master/datasketch/lsh.py#L22. The current version is both data-agnostic and not tuned toward any specific recall requirement.

Ideas welcome!
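The sketch below makes the weighted (b, r) trade-off concrete. It is a minimal, dependency-free search of the kind described in the comment above, assuming the following cost model:

- For every split with b * r <= num_perm, it integrates the candidate probability below the threshold (false positives) and the miss probability above it (false negatives).
- It minimizes weights[0] * FP + weights[1] * FN.

This is a re-implementation under that assumption, using a simple midpoint rule. It is not copied from lsh.py, may pick slightly different parameters than datasketch does, and all function names are hypothetical.

```python
# Hypothetical helper names; this is not datasketch's API.

def collision_probability(s: float, b: int, r: int) -> float:
    """Idealized probability that a pair with Jaccard similarity s shares a band."""
    return 1.0 - (1.0 - s ** r) ** b


def integrate(f, lo: float, hi: float, steps: int = 200) -> float:
    """Midpoint-rule approximation of the integral of f over [lo, hi]."""
    width = (hi - lo) / steps
    return width * sum(f(lo + (i + 0.5) * width) for i in range(steps))


def false_positive_area(threshold: float, b: int, r: int) -> float:
    # Candidate probability accumulated over similarities below the threshold.
    return integrate(lambda s: collision_probability(s, b, r), 0.0, threshold)


def false_negative_area(threshold: float, b: int, r: int) -> float:
    # Miss probability accumulated over similarities above the threshold.
    return integrate(lambda s: 1.0 - collision_probability(s, b, r), threshold, 1.0)


def choose_bands_rows(threshold: float, num_perm: int,
                      weights: tuple[float, float]) -> tuple[int, int]:
    """Exhaustively pick (b, r) with b * r <= num_perm minimizing the weighted error."""
    fp_weight, fn_weight = weights
    best_cost, best_params = float("inf"), (1, 1)
    for b in range(1, num_perm + 1):
        for r in range(1, num_perm // b + 1):
            cost = (fp_weight * false_positive_area(threshold, b, r)
                    + fn_weight * false_negative_area(threshold, b, r))
            if cost < best_cost:
                best_cost, best_params = cost, (b, r)
    return best_params


s = 97 / 103  # similarity of every value/query pair in this issue
for weights in [(0.5, 0.5), (0.1, 0.9)]:
    b, r = choose_bands_rows(0.9, 128, weights)
    print(weights, (b, r), f"P(candidate at s={s:.3f}) = {collision_probability(s, b, r):.3f}")
```

Both integrals are averages over all similarities on each side of the threshold, so this cost says nothing about recall at any particular similarity such as 0.94. That is the gap the comment points out: a recall-targeted variant would instead constrain collision_probability at a chosen similarity.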