aiidateam / disk-objectstore

An implementation of an efficient "object store" (actually, a key-value store) writing files on disk and not requiring a running server
https://disk-objectstore.readthedocs.io
MIT License
15 stars 8 forks source link

Efficient intersections and getting all elements #93

Closed giovannipizzi closed 3 years ago

giovannipizzi commented 4 years ago

When comparing two containers to decide what to send on the other side, it becomes important to be able to check what is (or is not) already on the destination.

The following code checks the content of two (sorted and unique) iterators and reports, for each item, which of them contains it, iterating only once over both in alternation.

Also, this shows how session.execute() returns a true iterator without pre-loading everything in memory.

Tasks:

Here is the function, followed by its output (IMPORTANT: THE FUNCTION BELOW HAS A COUPLE OF BUGS; A CORRECT IMPLEMENTATION WAS ADDED TO THE CODE IN COMMIT 38471b6):

#!/usr/bin/env python

import os
import time
import hashlib
import tqdm
from collections import OrderedDict

from disk_objectstore.models import Obj, Base
from disk_objectstore.utils import PackedObjectReader, get_hash

from sqlalchemy import create_engine, event
from sqlalchemy.sql import func
from sqlalchemy.orm import sessionmaker

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, Boolean, BLOB

# Declarative base for the tables of the output ("archive") database. It is kept separate from
# ``Base`` (imported from disk_objectstore.models), which get_session('read') binds instead.
BaseAr = declarative_base()  # pylint: disable=invalid-name,useless-suppression

# Path of the SQLite file opened by get_session('write'); its tables are created from BaseAr.
output_sqlite_file = '/scratch/TEST-DISK-OBJECTSTORE/tmp-db/archive.idx'

class ObjBlob(BaseAr):  # pylint: disable=too-few-public-methods
    """Row of the ``db_objblob`` table: one object per (unique) hash key, with its data stored as a BLOB.

    The class is registered on ``BaseAr``, so its table is created in the 'write' (archive) database.
    """
    __tablename__ = 'db_objblob'

    id = Column(Integer, primary_key=True)  # pylint: disable=invalid-name

    # Important: there are parts of the code that rely on the fact that this field is unique.
    # If you really do not want a uniqueness field, you will need to adapt the code.
    hashkey = Column(String, nullable=False, unique=True, index=True)
    # Whether ``data`` is stored compressed (see the comment on ``size`` below).
    compressed = Column(Boolean, nullable=False)
    size = Column(Integer, nullable=False)  # uncompressed size; if uncompressed, size == length
    # Presumably the object content itself (compressed when ``compressed`` is True) - TODO confirm with the writer.
    data = Column(BLOB)

def get_session(which):
    """Return a new SQLAlchemy session on one of the two SQLite databases used by this script.

    :param which: either ``'read'`` (the existing ``packs.idx`` database, described by ``Base``) or
        ``'write'`` (the archive database at ``output_sqlite_file``, described by ``BaseAr``; its
        tables are created if they do not exist yet).
    :return: a new session bound to the corresponding engine (autoflush and autocommit disabled).
    :raises ValueError: if ``which`` is neither ``'read'`` nor ``'write'``.
    """
    if which == 'read':
        engine = create_engine('sqlite:////scratch/TEST-DISK-OBJECTSTORE/tmp-db/packs.idx')
    elif which == 'write':
        engine = create_engine('sqlite:///{}'.format(output_sqlite_file))
    else:
        # An empty message gave the caller no hint of what went wrong.
        raise ValueError("Invalid value for 'which': {!r}; expected 'read' or 'write'".format(which))

    # For the next two bindings, see background on
    # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
    @event.listens_for(engine, 'connect')
    def do_connect(dbapi_connection, connection_record):  # pylint: disable=unused-argument,unused-variable
        """Hook function that is called upon connection.
        It modifies the default behavior of SQLite to use WAL and to
        go back to the 'default' isolation level mode.
        """
        # disable pysqlite's emitting of the BEGIN statement entirely.
        # also stops it from emitting COMMIT before any DDL.
        dbapi_connection.isolation_level = None
        # Open the file in WAL mode (see e.g. https://stackoverflow.com/questions/9671490)
        # This allows to have as many readers as one wants, and a concurrent writer (up to one)
        # Note that this writes on a journal, on a different packs.idx-wal,
        # and also creates a packs.idx-shm file.
        # Note also that when the session is created, you will keep reading from the same version,
        # so you need to close and reload the session to see the newly written data.
        # Docs on WAL: https://www.sqlite.org/wal.html
        cursor = dbapi_connection.cursor()
        cursor.execute('PRAGMA journal_mode=wal;')
        cursor.close()

    # For this binding, see background on
    # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
    @event.listens_for(engine, 'begin')
    def do_begin(conn):  # pylint: disable=unused-variable
        # emit our own BEGIN (pysqlite's implicit one was disabled in do_connect)
        conn.execute('BEGIN')

    # Create all tables in the engine. This is equivalent to "Create Table"
    # statements in raw SQL.
    if which == 'write':
        BaseAr.metadata.create_all(engine)
        BaseAr.metadata.bind = engine
    else:
        # Bind the engine to the metadata of the Base class so that the
        # declaratives can be accessed through a DBSession instance
        Base.metadata.bind = engine

    # We set autoflush = False to avoid to lock the DB if just doing queries/reads
    DBSession = sessionmaker(  # pylint: disable=invalid-name
        bind=engine, autoflush=False, autocommit=False
    )
    session = DBSession()

    return session

from enum import Enum

class Location(Enum):
    """Enum that describes if an element is only on the left or right iterator, or on both."""
    LEFTONLY = -1
    BOTH = 0
    RIGHTONLY = 1


# Private sentinel marking an exhausted iterator; unlike None, it can never be a real value.
_EXHAUSTED = object()


def _next_checked(iterator, previous, side):
    """Return the first element of the next row of ``iterator``, or ``_EXHAUSTED`` if there is none.

    :param iterator: iterator over indexable rows (e.g. SQL result rows or tuples).
    :param previous: value returned by the previous call for the same iterator, or ``_EXHAUSTED``
        on the first call (in which case no ordering check is done).
    :param side: ``'left'`` or ``'right'``; only used in the error message.
    :raises ValueError: if the new value is not strictly larger than ``previous``.
    """
    try:
        row = next(iterator)
    except StopIteration:
        return _EXHAUSTED
    value = row[0]
    if previous is not _EXHAUSTED and value <= previous:
        raise ValueError(
            "The {} iterator does not return sorted unique entries, I got '{}' after '{}'".format(
                side, value, previous
            ))
    return value


def detect_where(left_iterator, right_iterator):
    """Generator that loops in alternation (but only once each) the two iterators and yields an element, specifying if
    it's only on the left, only on the right, or in both.

    Each item of the iterators is a row; only its first element (``row[0]``) is compared and yielded.
    Pairs ``(value, Location)`` are yielded in ascending order of ``value``; a value present on both
    sides is yielded once, with ``Location.BOTH``. Any iterable (e.g. a list) is accepted.

    .. note:: IMPORTANT! The two iterators MUST return unique and sorted results.
    This function will check and raise a ValueError if it detects non-unique or non-sorted elements.
    HOWEVER, this exception is raised only at the first occurrence of the issue, that can be very late in the execution,
    so if you process results in a streamed way, please ensure that you pass sorted iterators.
    """
    # iter() returns iterators unchanged, and also lets callers pass lists or other iterables.
    left_iterator = iter(left_iterator)
    right_iterator = iter(right_iterator)

    last_left = _next_checked(left_iterator, _EXHAUSTED, 'left')
    last_right = _next_checked(right_iterator, _EXHAUSTED, 'right')

    # Standard sorted merge: at every step emit the smaller pending value and advance only the
    # iterator(s) it came from, so that no value is ever skipped or emitted twice.
    while last_left is not _EXHAUSTED or last_right is not _EXHAUSTED:
        if last_right is _EXHAUSTED or (last_left is not _EXHAUSTED and last_left < last_right):
            # Every remaining right value is larger (or there is none): last_left is left-only.
            yield last_left, Location.LEFTONLY
            last_left = _next_checked(left_iterator, last_left, 'left')
        elif last_left is _EXHAUSTED or last_right < last_left:
            # Symmetric case: last_right is right-only.
            yield last_right, Location.RIGHTONLY
            last_right = _next_checked(right_iterator, last_right, 'right')
        else:
            # Equal values: present in both, consume from both iterators.
            yield last_left, Location.BOTH
            last_left = _next_checked(left_iterator, last_left, 'left')
            last_right = _next_checked(right_iterator, last_right, 'right')

session_read = get_session('read')
session_write = get_session('write')

t = time.time()
count_read = list(session_read.execute("SELECT count(*) FROM db_object"))[0][0]
count_write = list(session_write.execute("SELECT count(*) FROM db_objblob"))[0][0]
# The label previously opened a parenthesis it never closed.
print(f"COUNT BOTH (read={count_read}, write={count_write})", time.time() - t)

# IMPORTANT: they must be ordered and unique (detect_where raises ValueError otherwise).
# db_objblob.hashkey is declared unique; uniqueness of db_object.hashkey is assumed - TODO confirm.
old_hashkeys_query = session_read.execute("SELECT hashkey FROM db_object ORDER BY hashkey")
new_hashkeys_query = session_write.execute("SELECT hashkey FROM db_objblob ORDER BY hashkey")

t = time.time()
intersection = set()
left_only = set()
right_only = set()
set_mapping = {
    Location.LEFTONLY: left_only,
    Location.BOTH: intersection,
    Location.RIGHTONLY: right_only,
}

# The context manager closes the progress bar even if detect_where raises.
with tqdm.tqdm(total=count_read + count_write) as bar:
    for hashkey, location in detect_where(old_hashkeys_query, new_hashkeys_query):
        set_mapping[location].add(hashkey)
        # We need to count 2 if the item is on both sides
        bar.update(2 if location == Location.BOTH else 1)
print("LIST BOTH UUIDS in", time.time() - t)

print("LEFT (OLD): ", len(left_only))
print("INTERSECT:  ", len(intersection))
print("RIGHT (NEW):", len(right_only))
print("SUM OF THE THREE=", len(intersection) + len(left_only) + len(right_only))
print("SUM OF ENTRIES ON LEFT AND RIGHT=", count_read + count_write)

print()
print("Performing final checks...")
# Sets cannot contain duplicates, so the meaningful check is that no element ended up in two
# of the three sets (done with set intersections rather than building one big list).
assert not left_only & intersection
assert not left_only & right_only
assert not right_only & intersection

# Everything returned - note that those in the intersection are in both so I need to count twice.
# Since sets de-duplicate, this also catches any value yielded more than once.
assert len(left_only) + 2 * len(intersection) + len(right_only) == count_read + count_write
print("DONE.")

session_read.close()
session_write.close()

OUTPUT (run on a DB of 6.7M nodes, and a subset of it with ~200k nodes):

 100%|█████████████████████████████████████████████████████| 6920262/6920262 [00:20<00:00, 341515.10it/s]
LIST BOTH UUIDS in 20.2684428691864
LEFT (OLD):  6509354
INTERSECT:   205454
RIGHT (NEW): 0
SUM OF THE THREE= 6714808
SUM OF ENTRIES ON LEFT AND RIGHT= 6920262

Performing final checks...
DONE.
giovannipizzi commented 4 years ago

Note that not putting the results into sets saves only a little time (~18s instead of 20s); most of the time is spent fetching the data from the DB.

giovannipizzi commented 4 years ago

Note: requesting from the DB only the first few characters of each hash key (say 8) instead of the whole key definitely saves memory, but it barely speeds up data retrieval: the time only goes from ~13.0s to 12.5s.

t = time.time()
count_read = list(session_read.execute("SELECT count(*) FROM db_object"))[0][0]
count_write = list(session_write.execute("SELECT count(*) FROM db_objblob"))[0][0]
print(f"COUNT BOTH read={count_read} write={count_write}", time.time() - t)

# The original split the keyword `total` across a backslash continuation ("to\" + "tal"),
# which is a SyntaxError: a backslash cannot continue a token.
t = time.time()
new_hashkeys = list(
    tqdm.tqdm(session_write.execute("SELECT hashkey FROM db_objblob ORDER BY hashkey"), total=count_write)
)
print("LIST ALL {} NEW UUIDS in".format(len(new_hashkeys)), time.time() - t)

# Fetch only the first NUMCHAR characters of each hash key, to compare retrieval time and memory.
NUMCHAR = 8
t = time.time()
old_prefixes = list(
    tqdm.tqdm(
        session_read.execute("SELECT substr(hashkey,1,{}) FROM db_object ORDER BY hashkey".format(NUMCHAR)),
        total=count_read,
    )
)
print("LIST ALL {} OLD UUIDS (ONLY FIRST {} CHARS) in".format(len(old_prefixes), NUMCHAR), time.time() - t)
print(old_prefixes[:10])
assert all(len(row[0]) == NUMCHAR for row in old_prefixes)

Output:

COUNT BOTH read=6714808 write=205454 0.8465695381164551
100%|███████████████████████████████| 205454/205454 [00:01<00:00, 204535.22it/s]
LIST ALL 205454 NEW UUIDS in 1.0100953578948975
100%|█████████████████████████████| 6714808/6714808 [00:12<00:00, 535991.18it/s]
LIST ALL 6714808 OLD UUIDS (ONLY FIRST 8 CHARS) in 12.544926643371582
[('0000000b',), ('0000029e',), ('00000571',), ('000006af',), ('00000815',), ('00000a24',), ('00000c2a',), ('00000c7e',), ('00000ec3',), ('0000113c',)]
100%|█████████████████████████████| 6714808/6714808 [00:12<00:00, 529509.90it/s]
LIST ALL 6714808 OLD UUIDS in 13.139374256134033
[('0000000b2775f652d71b1ec66477627d81c38ec65b4572f7af3de6fe103c4cab',), ('0000029ea9ed78cfa80c9da7f982657d6c4d85fa8646f0413a15c410462f7973',), ('00000571f966ac015fec7704982dbe9d43c390b1e532fada16f15a663f454cc2',), ('000006af96834385912af6cc93abdfa779d1da9173446c70cd56ca9ca17df8a0',), ('000008154aa9885b1b0e4de235c0b4232a0deb0107fdefd58281b8af0cfe0009',), ('00000a24c0fe196d892dc8c0a9930f81093d39902bc41e21266cb6ec8d001549',), ('00000c2a1a3c980a7ff90061abe63f31bc0e16bc18b815962eea151711aa3eac',), ('00000c7e4f1332f601c53ceaddb321d5505865756a591950da3ab683114cbdf8',), ('00000ec3686d0f887831830e15ce2ffe3d64bbb57fa071e2a0bc9afb196a8b15',), ('0000113c539a849568ac8eed03f261a5b8c74c3fdb093f17bc9c61d33907635e',)]