sustainable-processes / ORDerly

Chemical reaction data & benchmarks. Extraction and cleaning of data from Open Reaction Database (ORD)
MIT License

Better data generator #122

Closed: marcosfelt closed this 1 year ago

marcosfelt commented 1 year ago

This improves the data generator so that fingerprints are generated on the fly. We need to do this because, at the large fingerprint size (16384 bits), the dataset will not fit into memory on most machines.
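
For reference, here is a minimal sketch of what generating one such fingerprint with RDKit could look like; the helper name fingerprint and the Morgan radius are illustrative assumptions, not the PR's actual code:

```python
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem


def fingerprint(smiles: str, n_bits: int = 16384, radius: int = 2) -> np.ndarray:
    """Morgan fingerprint of one SMILES string as a dense float32 vector."""
    mol = Chem.MolFromSmiles(smiles)
    bitvect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    arr = np.zeros((n_bits,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(bitvect, arr)
    return arr.astype(np.float32)


# At 16384 float32 values per example, a fully materialised dataset quickly
# exceeds the memory of most machines, hence the on-the-fly generation.
x = fingerprint("CCO")  # shape (16384,)
```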

tl;dr: I used caching on disk, so you only have to run the fingerprint generation once, but you also don't have to load the whole dataset into memory. To use this, add --cache_train_data=True and/or --cache_val_data=True to the command line. I also believe the cache will not be regenerated on each run, so you have to manually delete the cache directories (e.g., .tf_train_cache) to get it to regenerate if you've made a change to the training data (a sketch of that step follows below).
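
As a concrete (hypothetical) example of that manual invalidation step, deleting the cache directory from Python could look like this; the path matches the example above:

```python
import shutil

# The file-backed cache is not invalidated automatically: after changing the
# training data, remove the cache directory so the fingerprints are
# regenerated on the next run.
shutil.rmtree(".tf_train_cache", ignore_errors=True)
```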

Explanation

TensorFlow has two APIs for datasets: keras.utils.Sequence and tf.data.Dataset. The former is what I used first and is quite straightforward, but the official TensorFlow recommendation is that all new code use tf.data.Dataset. Furthermore, I was never able to get parallel processing of the fingerprints working with keras.utils.Sequence. Therefore, I switched over to tf.data.Dataset.
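
A minimal sketch of the recommended API, assuming the fingerprint helper above and toy placeholder data (not the PR's actual pipeline):

```python
import tensorflow as tf

# Toy placeholder data; fingerprint() is the RDKit helper sketched earlier.
training_rows = [("CCO", 1.0), ("c1ccccc1", 0.0)]


def generate_examples():
    # Compute fingerprints on the fly instead of materialising them up front.
    for smiles, label in training_rows:
        yield fingerprint(smiles), label


dataset = tf.data.Dataset.from_generator(
    generate_examples,
    output_signature=(
        tf.TensorSpec(shape=(16384,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
    ),
)
```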

The challenge is that tf.data.Dataset is heavily optimized for computations that can happen entirely within TensorFlow's AutoGraph (i.e., operations on tensors). Since the fingerprint generation runs outside of TensorFlow (i.e., in RDKit), we can't take full advantage of the parallelization features of tf.data. You'll see in the code that I tried to implement something like this blog post, but I found no significant speed-up.
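
A rough sketch of what that attempt looks like, assuming the same hypothetical fingerprint helper and toy data (the PR's real code may differ): the RDKit call is pushed through tf.py_function inside .map() so that tf.data can try to parallelise it.

```python
import tensorflow as tf


def to_fingerprint(smiles, label):
    # tf.py_function drops back into eager Python to run the RDKit code.
    fp = tf.py_function(
        func=lambda s: fingerprint(s.numpy().decode("utf-8")),
        inp=[smiles],
        Tout=tf.float32,
    )
    fp.set_shape((16384,))
    return fp, label


smiles_ds = tf.data.Dataset.from_tensor_slices(
    (["CCO", "c1ccccc1"], [1.0, 0.0])  # toy placeholder data
)
# In the PR this parallel map gave no significant speed-up on its own.
parallel_ds = smiles_ds.map(to_fingerprint, num_parallel_calls=tf.data.AUTOTUNE)
```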

The trick that worked was to cache the dataset after the fingerprints have been generated for the first time. Although it is not obvious from the TensorFlow documentation for caching, if you specify a file to cache to on disk, the dataset will read back from the on-disk cache only the data needed for the current or prefetched batches. This is because the caching uses TFRecords, an efficient storage mechanism with random access to any row stored on disk.
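
Continuing the sketch above, the file-backed caching pattern looks roughly like this; the cache path, shuffle buffer, and batch size are illustrative values, not the PR's settings:

```python
import os

import tensorflow as tf

# The parent directory for the cache files must exist before the first run.
os.makedirs(".tf_train_cache", exist_ok=True)

# The first pass computes the fingerprints and writes them to disk; later
# epochs (and later runs, as long as the cache files are kept) stream batches
# back from the on-disk cache instead of recomputing them.
cached_ds = (
    parallel_ds
    .cache(".tf_train_cache/train")
    .shuffle(buffer_size=10_000)
    .batch(256)
    .prefetch(tf.data.AUTOTUNE)
)
```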

In practice, the speed-up from this change was huge: training throughput increased more than 50x, from ~400 examples/s to 25k+ examples/s. This means we can now train models with both a large batch size and a large fingerprint size in short run times.