MolSSI / QCFractal

A distributed compute and database platform for quantum chemistry.
https://molssi.github.io/QCFractal/
BSD 3-Clause "New" or "Revised" License

Batch record retrieval #742

Closed chrisiacovella closed 9 months ago

chrisiacovella commented 11 months ago

This issue relates to downloading large batches of records, such as fetching an entire dataset.

There are two main hurdles I've run into with batch downloads: (1) efficiency of the process and (2) loss of connection to the database during retrieval.

Note that this issue is related to issues #740 and #741.

Here, I will outline some benchmarks and my attempted solutions.

Serial retrieval of records benchmarks

Currently if I want to fetch records I would do something similar to:

from qcportal import PortalClient
import time

client = PortalClient()
ds = client.get_dataset(dataset_type='singlepoint', dataset_name='QM9')

#grab the entry names
entry_names = ds.entry_names

# a barebones function to fetch relevant info from the database
def get_records(start, end):

    local_records = {}

    for record_name in entry_names[start:end]:
        temp = ds.get_entry(record_name).dict()
        temp2 = ds.get_record(record_name, specification_name='spec_2').dict()
        local_records[record_name] = [temp, temp2]

    return local_records

The performance for fetching a small number of records is good:

%time temp_records = get_records(start = 0, end = 10)
> CPU times: user 130 ms, sys: 32.5 ms, total: 163 ms
> Wall time: 3.66 s

Since this is a serial process, the timing scales (as expected) linearly with the number of records.

Based on this scaling, fetching the whole QM9 dataset (with ~133K records) would take about 13 hours (in practice this is about right).
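The extrapolation is simple arithmetic. A quick check, assuming the per-record cost measured in the 10-record run above stays constant for the whole dataset:

```python
# Wall time measured above for 10 serial get_entry/get_record pairs
wall_time_s = 3.66
n_measured = 10

# Approximate number of records in QM9
n_total = 133_000

# Assumes cost per record is constant (no batching, no caching)
seconds_per_record = wall_time_s / n_measured
total_hours = seconds_per_record * n_total / 3600

print(f"{seconds_per_record:.3f} s/record -> {total_hours:.1f} h")
```

At roughly 0.37 s per record, 133K records come out to about 13.5 hours, consistent with the observed runtime.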

Multithreaded retrieval benchmarks

I used concurrent.futures to wrap the get_records function above so that it runs in multiple threads concurrently. I chose concurrent.futures because it works in Jupyter notebooks without any special tricks. The code below is a bare-bones implementation of multithreaded retrieval: it groups 25 records per task (chunk_size) and allows up to 48 concurrent threads (n_threads). These values seem reasonable on my 8-core MacBook.

from concurrent.futures import ThreadPoolExecutor, as_completed

def run_threaded(max_records):
    range_list = []

    chunk_size = 25 
    n_threads = 48 # this seemed to be about the sweet spot on my machine

    for i in range(int(max_records/chunk_size)+1):
        start = i*chunk_size
        if start >= max_records:
            break
        end = (i+1)*chunk_size
        if end > max_records:
            end = max_records

        range_list.append([start, end])

    threads= []
    overall_records = {}
    with ThreadPoolExecutor(max_workers=n_threads) as executor:
        for min_max in range_list:
            threads.append(executor.submit(get_records, min_max[0], min_max[1]))

        for task in as_completed(threads):
            # merge each chunk's results in place (avoids copying the whole dict every time)
            overall_records.update(task.result())
    return overall_records
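The chunking and merging can also be factored into a reusable helper that knows nothing about qcportal. A minimal sketch, where make_ranges, fetch_in_chunks, and fake_fetch are hypothetical names; fetch_chunk can be any callable that takes (start, end) and returns a dict, such as get_records above:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def make_ranges(n_items, chunk_size):
    """Split [0, n_items) into half-open (start, end) chunks of at most chunk_size."""
    return [(start, min(start + chunk_size, n_items))
            for start in range(0, n_items, chunk_size)]


def fetch_in_chunks(fetch_chunk, n_items, chunk_size=25, n_threads=48):
    """Run fetch_chunk(start, end) for every chunk on a thread pool and merge the dicts."""
    merged = {}
    with ThreadPoolExecutor(max_workers=n_threads) as executor:
        futures = [executor.submit(fetch_chunk, start, end)
                   for start, end in make_ranges(n_items, chunk_size)]
        for future in as_completed(futures):
            # update() merges in place; result() re-raises any exception from the worker
            merged.update(future.result())
    return merged


# Stand-in fetch function so the sketch runs without a server
def fake_fetch(start, end):
    return {f"entry_{i}": i for i in range(start, end)}


result = fetch_in_chunks(fake_fetch, n_items=103, chunk_size=25, n_threads=4)
print(len(result))
```

With real data, fetch_in_chunks(get_records, len(entry_names), chunk_size=25, n_threads=48) would mirror run_threaded. The list comprehension in make_ranges replaces the manual start/end bookkeeping and handles the final partial chunk automatically.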

Benchmarks for fetching records:

Extrapolating to the whole QM9 dataset, it would take under 15 minutes to fetch the 133K records, which is quite reasonable compared to 13 hours.

Loss of database connectivity

If I increase the number of records to, say, 10,000, at some point during the process I get the following error:

PortalRequestError: Request failed: Bad Gateway (HTTP status 502)

Note that if I simply turn off my Wi-Fi (to simulate an internet hiccup), I get the following error:

Could not connect to server https://api.qcarchive.molssi.org/, please check the address and try again.

I modified the code above (shown below) to wrap the record retrieval in a try/except block, inside a while loop that retries at most 1000 times. Note that instead of storing the fetched data, this version just counts the number of "failures" to connect to the server.

I had this code fetch 10,000 records three times; the timing was very consistent at about 100 s.

Interestingly, the number of connection failures barely changed between the three runs: 405, 407, and 408. If my internet were simply flaky, I would have expected more variation. Since the data is broken into 25-record chunks, there are 400 chunks (tasks) in total, so ~400 failures seems suspect. Digging into these numbers, every task ends up in the except block on its first call to the portal. This is confusing at first, because the first threaded implementation (without try/except) works fine. The likely explanation is Python's scoping rules: because get_records_try assigns to client and ds inside its except block, both names are local to the whole function, so the first ds.get_entry() call raises UnboundLocalError, which the bare except catches before creating a fresh client and dataset. Either way, 400 of those "failures" aren't real, so we are really dealing with 5, 7, and 8, i.e., fewer than 1 failure to connect per 1000 record fetches.

Issue #741 suggests retrying behind the scenes, which seems more user-friendly and would allow a smarter approach to connection retries.

Adding multithreaded get_records() and get_entries() functions to qcportal would be quite beneficial.

def get_records_try(start, end):

    local_records = {}
    failures = 0 
    for i, record_name in enumerate(entry_names[start:end]):
        fetch_record = True
        attempt = 0
        max_iterations = 1000
        while fetch_record == True and attempt < max_iterations:
            try:
                temp = ds.get_entry(record_name).dict()
                temp2 = ds.get_record(record_name, specification_name='spec_2').dict()

                local_records[record_name] = [temp, temp2]
            except:
                time.sleep(float(attempt/max_iterations))
                client = PortalClient()
                ds = client.get_dataset(dataset_type='singlepoint', dataset_name='QM9')
                failures += 1
            else:
                fetch_record = False

            attempt += 1

    return failures

def run_threaded_safe(max_records):
    range_list = []

    chunk_size = 25 
    n_threads = 48 # this seemed to be about the sweet spot on my machine

    for i in range(int(max_records/chunk_size)+1):
        start = i*chunk_size
        if start >= max_records:
            break
        end = (i+1)*chunk_size
        if end > max_records:
            end = max_records

        range_list.append([start, end])

    threads= []
    failures = []
    with ThreadPoolExecutor(max_workers=n_threads) as executor:
        for min_max in range_list:
            threads.append(executor.submit(get_records_try, min_max[0], min_max[1]))

        for task in as_completed(threads):
            failures.append(task.result())
    print(len(range_list))
    print(len(threads))
    return failures
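Along the lines of #741, retries can be centralized in one helper rather than hand-rolled in each loop. A minimal sketch, where call_with_retry and its parameters are hypothetical (not part of qcportal): it retries with exponential backoff plus jitter, re-raises the real error when it gives up, and because it receives the fetch as a callable instead of rebinding client or ds inside a function, it cannot accidentally turn those globals into locals.

```python
import random
import time


def call_with_retry(func, *args, max_attempts=5, base_delay=0.5, max_delay=30.0,
                    retry_on=(Exception,), **kwargs):
    """Call func(*args, **kwargs), retrying on `retry_on` exceptions with
    exponential backoff plus jitter; re-raise the last error if every attempt fails."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return func(*args, **kwargs)
        except retry_on:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the real error
            # Backoff doubles each time: base_delay, 2*base_delay, 4*base_delay, ... (capped)
            delay = min(max_delay, base_delay * 2 ** attempt)
            # Random jitter keeps many threads from retrying in lockstep
            time.sleep(delay * random.uniform(0.5, 1.0))


# Demo with simulated failures (no server involved)
calls = {"n": 0}


def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("simulated Bad Gateway")
    return "ok"


flaky_result = call_with_retry(flaky, base_delay=0.01)


def always_fail():
    raise ConnectionError("simulated outage")


try:
    call_with_retry(always_fail, max_attempts=2, base_delay=0.01)
    gave_up = False
except ConnectionError:
    gave_up = True

print(flaky_result, calls["n"], gave_up)
```

Inside get_records this might look like call_with_retry(ds.get_record, record_name, specification_name='spec_2', retry_on=(PortalRequestError,)), assuming PortalRequestError is imported from qcportal.client_base as in the tracebacks in this thread; connection failures may raise a different exception type, so the retry_on tuple would need to cover that too.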
chrisiacovella commented 11 months ago

Multithreading performance breakdown

I ran a few benchmarks with different chunk sizes and thread counts on my 8-core and 32-core machines, mostly because I was curious how much they affect the results shown above. For the number of records fetched per thread (chunk_size), the sweet spot seems to be 5-10 records; for the number of concurrent threads, 8-10 threads per core.

Fetching 1000 records on an 8-core machine

chunk_size  n_threads  time (s)
1           48         31.9
2           48         26.2
5           48         13.3
10          48         13.7
25          48         16.2
50          48         27.6

chunk_size  n_threads  time (s)
10          24         25.2
10          48         13.7
10          64         10.9
10          96         11.9

Fetching 2000 records on an 8-core machine

chunk_size  n_threads  time (s)
10          24         43.7
10          48         23.6
10          64         18.4
10          96         16.9
10          128        21.2

Fetching 2000 records on a 32-core machine

chunk_size  n_threads  time (s)
10          24         35.4
10          48         21.1
10          64         20.9
10          96         19.8
10          128        14.5
10          196        11
10          200        9.81

Fetching 3000 records on a 32-core machine (chunk_size = 5)

chunk_size  n_threads  time (s)
5           24         57.5
5           48         31.2
5           64         27
5           96         26.1
5           128        21.6
5           196        17.7
5           256        15.5
5           288        15.2
5           320        15.4
5           384        11.2
5           416        10.9
5           512        11.5

Fetching 3000 records on a 32-core machine (chunk_size = 10)

chunk_size  n_threads  time (s)
10          24         52.4
10          48         30.6
10          64         25.7
10          96         24
10          128        19.2
10          196        14.4
10          256        14.1
10          288        15.1
10          300        15
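Runs of different sizes are easier to compare once normalized to throughput (records per second) and threads per core. A small sketch of that arithmetic, using only the fastest row of each table above (for 1000 records, the row from the n_threads table):

```python
# Fastest row of each table: (cores, n_records, chunk_size, n_threads, time_s)
best_runs = [
    (8, 1000, 10, 64, 10.9),
    (8, 2000, 10, 96, 16.9),
    (32, 2000, 10, 200, 9.81),
    (32, 3000, 5, 416, 10.9),
    (32, 3000, 10, 256, 14.1),
]

summary = []
for cores, n_records, chunk_size, n_threads, time_s in best_runs:
    threads_per_core = n_threads / cores
    throughput = n_records / time_s  # records per second
    summary.append((threads_per_core, throughput))
    print(f"{cores:>2} cores, {n_records} records, chunk_size {chunk_size:>2}: "
          f"{n_threads:>3} threads ({threads_per_core:.1f} per core) -> {throughput:.0f} records/s")
```

Since throughput is what ultimately determines the time for a full dataset, this view makes it clearer whether adding cores or threads actually pays off at a given dataset size.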
bennybp commented 11 months ago

A few notes:

Getting records one-by-one will always be very slow because of various overheads. Datasets have an iterate_records function which will download records in batches behind the scenes. There is still an opportunity to optimize that further with something like a thread pool, but it is still much faster than getting records one-by-one.
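The gap between one-by-one and batched retrieval is mostly per-request overhead: N individual calls pay it N times, while batching pays it roughly N / batch_size times. A minimal, library-independent sketch of a batched iterator; iterate_in_batches and fake_fetch_batch are hypothetical names, with fake_fetch_batch standing in for a single server request, and this is not qcportal's actual implementation:

```python
from typing import Callable, Dict, Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def iterate_in_batches(
    names: Sequence[str],
    fetch_batch: Callable[[List[str]], Dict[str, T]],
    batch_size: int = 100,
) -> Iterator[Tuple[str, T]]:
    """Yield (name, item) pairs, requesting batch_size names per fetch_batch call."""
    for start in range(0, len(names), batch_size):
        batch = list(names[start:start + batch_size])
        fetched = fetch_batch(batch)  # one round trip for the whole batch
        for name in batch:
            if name in fetched:
                yield name, fetched[name]


# Stand-in for a server request that counts the number of round trips
request_count = 0


def fake_fetch_batch(batch):
    global request_count
    request_count += 1
    return {name: len(name) for name in batch}


names = [f"mol_{i}" for i in range(250)]
items = dict(iterate_in_batches(names, fake_fetch_batch, batch_size=100))
print(len(items), request_count)  # 250 items in 3 requests instead of 250
```

In qcportal this batching is handled by ds.iterate_records(...), which yields (entry_name, specification_name, record) tuples, as unpacked in the later snippets in this thread.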

The records obtained from iterate_records are also cached internally by the dataset object.

The Bad Gateway error is a little concerning. I will try to reproduce it. It points to an issue with handling the requests on my end, which shouldn't happen. Is it happening only when you use lots of threads?

Possibly related: I did some investigating yesterday, and we are indeed having some networking hardware issues that we are trying to pinpoint. That might be adding some slowdown and variability to your requests as well.

chrisiacovella commented 11 months ago

iterate_records and iterate_entries are definitely a substantial improvement over get_record and get_entry. Fetching the 133K records took just over 9 minutes (for completeness, the code snippet is below), which is faster than the multithreaded get_records approach (~15 minutes).

Note that the Bad Gateway error occurs regardless of whether I use a threaded or non-threaded implementation. Even with iterate_records, I still get a Bad Gateway error roughly 1 out of every 3 times I try to fetch the dataset.

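local_entries = {}
local_records = {}

def iterate_records():
    # iterate_entries and iterate_records download in batches behind the scenes;
    # iterate_records yields (entry_name, specification_name, record) tuples
    for entry, record in zip(ds.iterate_entries(force_refetch=True), ds.iterate_records(specification_names=['spec_2'], force_refetch=True)):
        entry_dict = entry.dict()
        local_entries[entry_dict['name']] = entry_dict
        local_records[record[0]] = record[2].dict()
    return local_records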
bennybp commented 11 months ago

FYI, the networking issues here at VT seem to have been resolved, so hopefully things are generally faster. I will look more deeply into the Bad Gateway errors, but at the moment I can't reproduce them.

chrisiacovella commented 11 months ago

I just ran a bunch of "tests" (essentially just running a few loops of grabbing the same dataset above) with no gateway errors.

bennybp commented 11 months ago

Ok, that's funny, because now I am able to reproduce the Bad Gateway error (from a Jupyter notebook running on a VM on the server, no less). So I will investigate that.

bennybp commented 9 months ago

I think this has been resolved, but if you run into problems again let me know!

peastman commented 7 months ago

I'm running into this same problem. I've been trying all day to download data with the script at https://github.com/openmm/spice-dataset/tree/main/downloader. It invariably fails with the exception

  File "/home/peastman/workspace/spice-dataset/downloader/downloader.py", line 111, in <module>
    recs = list(dataset.iterate_records(specification_names=specifications))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peastman/miniconda3/envs/qcportal/lib/python3.11/site-packages/qcportal/dataset_models.py", line 972, in iterate_records
    self._internal_fetch_records(batch_tofetch, [spec_name], status, include)
  File "/home/peastman/miniconda3/envs/qcportal/lib/python3.11/site-packages/qcportal/dataset_models.py", line 716, in _internal_fetch_records
    record_info = self._client.make_request(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peastman/miniconda3/envs/qcportal/lib/python3.11/site-packages/qcportal/client_base.py", line 408, in make_request
    r = self._request(
        ^^^^^^^^^^^^^^
  File "/home/peastman/miniconda3/envs/qcportal/lib/python3.11/site-packages/qcportal/client_base.py", line 373, in _request
    raise PortalRequestError(f"Request failed: {details['msg']}", r.status_code, details)
qcportal.client_base.PortalRequestError: Request failed: Bad Gateway (HTTP status 502)

though only after working for anywhere from 20 minutes to 2.5 hours. I've tried from two different computers: one at home with a not-so-great internet connection, and one in a data center at Stanford with a very high-speed connection. Both fail.

peastman commented 7 months ago

The downloads are also going ridiculously slowly, sometimes taking over an hour for a single dataset. Almost all the time is spent in just two lines:

    recs = list(dataset.iterate_records(specification_names=specifications))

and

    all_molecules = client.get_molecules([r.molecule_id for e, s, r in recs])
bennybp commented 7 months ago

The 502 errors are hard to debug, and something I've occasionally seen before. I see them in the logs but I don't see much more detail than that. I think it's related to some interaction between Traefik and gunicorn (where no workers are available), but I need to dig deeper. Unfortunately it is hard to reproduce.

I'm running the downloader script now and it seems to be chugging along. It could just be the amount of data being downloaded and being limited by cross-country bandwidth.

If you run the following (which just fetches all the records of the first dataset) how long does it take for you? For me it takes about 2 minutes.

time python3 -c "import qcportal as ptl;c=ptl.PortalClient('https://ml.qcarchive.molssi.org');ds=c.get_dataset_by_id(343);print(ds.name);recs=list(ds.iterate_records())"

The good-ish news is that I have been working on the local caching stuff, and that is almost ready. So this kind of stuff might become much easier in the future.

I do see that when getting records, there is a needlessly complicated query happening, although I doubt it affects it too much.

peastman commented 7 months ago
$ time python3 -c "import qcportal as ptl;c=ptl.PortalClient('https://ml.qcarchive.molssi.org');ds=c.get_dataset_by_id(343);print(ds.name);recs=list(ds.iterate_records())"
SPICE Solvated Amino Acids Single Points Dataset v1.1

real    4m3.647s
user    0m20.943s
sys 0m5.486s

If I also retrieve the molecules it takes much longer.

$ time python3 -c "import qcportal as ptl;c=ptl.PortalClient('https://ml.qcarchive.molssi.org');ds=c.get_dataset_by_id(343);print(ds.name);recs=list(ds.iterate_records());mols=c.get_molecules([r.molecule_id for e, s, r, in recs])"
SPICE Solvated Amino Acids Single Points Dataset v1.1

real    52m5.113s
user    44m5.881s
sys 4m7.884s
bennybp commented 7 months ago

Oh my, this is not what I expected. I added some print statements for each part of fetching molecules, and the requests themselves are handled just fine (100-200 ms). But parsing the JSON into molecule objects is taking far too long. I need to look at this ASAP.

Getting 250 molecules
    Request time: 0.18s Return size: 1595493
    Deserialization time: 0.02s
    Model parse time: 113.40s

Almost 2 minutes to convert JSON to 250 molecules is definitely not right.
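The breakdown above separates request time, deserialization, and model parsing. A minimal sketch of collecting that kind of per-stage timing with a context manager and time.perf_counter; timed and the three stages below are hypothetical stand-ins built from the json module, not QCFractal's actual code:

```python
import json
import time
from contextlib import contextmanager

timings = {}


@contextmanager
def timed(stage):
    """Record the wall time spent inside the `with` block under `stage`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[stage] = time.perf_counter() - start


# Hypothetical stand-ins for the three stages of fetching molecules
with timed("request"):
    payload = json.dumps([{"symbols": ["O", "H", "H"], "id": i} for i in range(250)])
with timed("deserialization"):
    raw = json.loads(payload)
with timed("model parse"):
    molecules = [dict(m, n_atoms=len(m["symbols"])) for m in raw]

for stage, seconds in timings.items():
    print(f"{stage:>16}: {seconds:.4f}s")
print(f"Return size: {len(payload)}")
```

With the real stages in place, dividing the parse time by the batch size (113.40 s / 250 ≈ 0.45 s per molecule, versus 0.18 s for the entire request) shows at a glance which stage is the outlier.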

bennybp commented 7 months ago

Ok, I have a PR (#798) for the server that should fix this, but I need to test it a bit to make sure it actually fixes the issue.

peastman commented 7 months ago

Thanks for the quick fix! Let me know when it's ready to test.

bennybp commented 7 months ago

New release is out and the ML server has been upgraded. Here's what I get now:

time python3 -c "import qcportal as ptl;c=ptl.PortalClient('https://ml.qcarchive.molssi.org');ds=c.get_dataset_by_id(343);print(ds.name);recs=list(ds.iterate_records());mols=c.get_molecules([r.molecule_id for e, s, r, in recs])"
SPICE Solvated Amino Acids Single Points Dataset v1.1

real    2m44.574s
user    0m59.002s
sys 0m3.376s
peastman commented 7 months ago
$ time python3 -c "import qcportal as ptl;c=ptl.PortalClient('https://ml.qcarchive.molssi.org');ds=c.get_dataset_by_id(343);print(ds.name);recs=list(ds.iterate_records());mols=c.get_molecules([r.molecule_id for e, s, r, in recs])"
SPICE Solvated Amino Acids Single Points Dataset v1.1

real    2m58.800s
user    0m23.878s
sys 0m4.934s

Much better!

peastman commented 7 months ago

I'm still getting this error. I've been trying all day to download dataset 348. It invariably fails after anything from a few minutes to a couple of hours. Usually with the Bad Gateway error, occasionally with a read timeout.

bennybp commented 7 months ago

Looking through the logs I do see a hint where the instance is running out of memory. I've increased the memory allocated to it through docker, and also reconfigured the number of processes/threads handling requests for each container. Let's see if that helps.

Ok let me see if I can reproduce this tomorrow on a different instance, but I think it is something subtle like this.

This is annoying but thanks for being patient!

peastman commented 7 months ago

Thanks! I'm trying again.

peastman commented 7 months ago

Success, thanks!

bennybp commented 7 months ago

Yes, I haven't seen any additional errors. But the increasing memory usage is unsettling.

Something in SQLAlchemy is causing memory usage to grow over time. I can reproduce it with QCFractal, even outside of Flask, and I can see which data structures are being held on to too long, but it's hard to reproduce with a small self-contained script.