elixir-nx / scidata

Download and normalize datasets related to science
Apache License 2.0

Input pipeline API #13

Open seanmor5 opened 3 years ago

seanmor5 commented 3 years ago

Related to discussion in #11 and would resolve https://github.com/elixir-nx/axon/issues/25

I have been digging into input pipelines, specifically tf.data and torch.utils.data.DataLoader. I'm more familiar with tf.data, and it has a more intentionally functional pattern, so I'm mostly biased towards an API very similar to theirs. Here's a recent paper on tf.data as well.

The goal of an efficient input pipeline is to keep accelerators as busy as possible. There are a lot of places bottlenecks can happen (large IO operations, data transfer to the GPU, slow input transformations, etc.), so it's important that any implementation be as performance-sensitive as possible. tf.data has some interesting benchmarks where they simulate an "infinitely fast" neural network to measure the absolute throughput of their API, and I believe they achieve something like 13k images processed per second - it would be interesting to replicate some of these benchmarks, but I won't get ahead of myself.

Input pipelines can be characterized in 3 phases: Extract. Transform. Load. I'll briefly summarize the stages and their challenges.

Extract

This stage is reading data from storage - think loading images from directories or streaming text from files. It's heavily IO-bound and slow, and because most practical datasets are massive, loading the entire dataset into memory up front is impractical - the data has to be streamed in.

Transform

This stage is applying transformations and preprocessing to the input data. This could be anything from image augmentation to applying masks, padding, etc. Most operations are compute-intensive; however, because the accelerator should be busy doing the actual training, transformation work is most efficiently offloaded to the CPU.

Load

This stage actually loads data into accelerators. Transferring data from CPU to an accelerator can prove costly, but there are some tricks such as staging / prefetching input buffers to improve performance.

tf.data's main abstraction is tf.data.Dataset, which represents a stateless input pipeline. A tf.data.Dataset is analogous to an Elixir Stream: it can be transformed with functions such as filter, flat_map, batch, map, reduce, etc., and these transformations are fused into a graph that can then be statically optimized. The input pipeline also offers the ability to "prefetch" data (stage it for efficient transfer to accelerators), cache data so it's read from memory or faster storage later on, as well as "dynamic" optimizations that tune the parallelism and memory usage of the pipeline.
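To make the Stream analogy concrete, here is roughly what a small pipeline looks like when expressed directly over Elixir Streams (the decode_image/1 helper and the file layout are placeholders for illustration, not anything real):

images =
  "train/*.png"
  |> Path.wildcard()
  |> Stream.map(&File.read!/1)
  |> Stream.map(&decode_image/1)
  |> Stream.chunk_every(32)

# Nothing is read or decoded until the stream is consumed,
# e.g. Enum.take(images, 1) pulls only the first batch of 32.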

Based on the considerations above, I propose we create an input pipeline abstraction very similar to tf.data based on Streams. Here are my initial thoughts:

First, we can define a struct that stores the actual input / label stream as well as metadata:

defstruct [:input, :label, :input_shape, :input_type, :label_shape, :label_type, :supervised]

So we can capture shape / type information if necessary. :supervised is true in cases where labels are present and false when they are not.
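For illustration, a minimal module wrapping that struct might look like the following (the Scidata.Pipeline name and the new/3 constructor are placeholders, not a proposed API):

defmodule Scidata.Pipeline do
  defstruct [:input, :label, :input_shape, :input_type,
             :label_shape, :label_type, :supervised]

  # Wrap raw input/label streams and record any metadata passed in opts.
  def new(input, label \\ nil, opts \\ []) do
    %__MODULE__{
      input: input,
      label: label,
      input_shape: opts[:input_shape],
      input_type: opts[:input_type],
      label_shape: opts[:label_shape],
      label_type: opts[:label_type],
      supervised: label != nil
    }
  end
end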

We'll then have a number of "extract" functions that return new pipelines from a variety of formats. It should also be trivial to create new "extract" functions, but we should cover the most common cases and make sure they are as optimized as possible:

from_stream(stream, opts \\ []) :: pipeline
from_files(files, opts \\ []) :: pipeline
...
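As a sketch of what one of these could look like, from_files/2 might only read and parse a file when the element is actually demanded downstream (the :parser option is an assumption for illustration, and Scidata.Pipeline is the placeholder module from above):

def from_files(files, opts \\ []) do
  # Defaults to reading raw bytes; a real implementation would likely
  # plug in format-specific decoders here.
  parser = Keyword.get(opts, :parser, &File.read!/1)

  %Scidata.Pipeline{
    input: Stream.map(files, parser),
    supervised: false
  }
end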

We'll also have a number of transformations. I think we might be able to do some static optimizations and fusions of our own, and we can ensure each transformation is jitted by default.

batch(pipeline, batch_size) :: pipeline
map(pipeline, map_fn) :: pipeline
filter(pipeline, filter_fn) :: pipeline
repeat(pipeline, repeat_size) :: pipeline
shuffle(pipeline, shuffle_size) :: pipeline
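As a rough sketch of how a few of these might map onto Stream operations (assuming the placeholder struct above, elements that are Nx tensors, and Nx.stack for collating batches - labels would be handled the same way as inputs):

def batch(%Scidata.Pipeline{input: input} = pipeline, batch_size) do
  batched =
    input
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&Nx.stack/1)

  %{pipeline | input: batched}
end

def map(%Scidata.Pipeline{input: input} = pipeline, map_fn) do
  %{pipeline | input: Stream.map(input, map_fn)}
end

# A windowed shuffle: buffer shuffle_size elements, permute them, then emit.
def shuffle(%Scidata.Pipeline{input: input} = pipeline, shuffle_size) do
  shuffled =
    input
    |> Stream.chunk_every(shuffle_size)
    |> Stream.flat_map(&Enum.shuffle/1)

  %{pipeline | input: shuffled}
end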

We'll also want some "performance" based functions, although I haven't really thought about how these can be most efficiently implemented:

prefetch(pipeline, size) :: pipeline
cache(pipeline, size) :: pipeline
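I haven't benchmarked anything, but as a very rough sketch of the prefetch idea using only OTP primitives: Task.async_stream can keep up to size transformations in flight ahead of the consumer. Here the transform and the prefetch are folded together (map_prefetched is a hypothetical name, not a proposed API, and a real implementation would probably want a dedicated buffering process or GenStage):

def map_prefetched(%Scidata.Pipeline{input: input} = pipeline, map_fn, size) do
  mapped =
    input
    |> Task.async_stream(map_fn, max_concurrency: size, ordered: true)
    |> Stream.map(fn {:ok, elem} -> elem end)

  %{pipeline | input: mapped}
end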

We'll also want ways to lazily iterate through the dataset, although this can most likely be done directly on the input and label streams - for example, by zipping them (sketch below). This kind of application IMO suits Elixir very well. I'm not really experienced with GenStage, Broadway, Flow, etc., so maybe they are useful here and somebody else can comment, or maybe they are irrelevant and I'll not bring them up again :)
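For example, pulling a few batches from a supervised pipeline could be as simple as zipping the two streams and taking what you need, without realizing the rest of the dataset (again just a sketch against the placeholder struct):

%Scidata.Pipeline{input: input, label: label} = pipeline

examples =
  input
  |> Stream.zip(label)
  |> Enum.take(3)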

I believe the responsibility of building out this API probably best fits in this library, unless we want to limit the scope of Scidata to just datasets and move the pipeline logic elsewhere. Pending any feedback, I can start putting some basic things together in a PR and then begin working on integrating it with Axon.

t-rutten commented 3 years ago

These specs seem great to me!

I think it would be reasonable to build out the pipeline abstraction as well as the pipeline formation of common datasets in this library. What do you think @josevalim and @wojtekmach?

josevalim commented 3 years ago

Thanks @seanmor5 for the write-up! I am thinking this can also be part of Nx, since loading data is going to be an issue for almost all apps using Nx. But it is hard to say before we add streaming to devices (which is still on my plate).