seanmor5 opened this issue 3 years ago
These specs seem great to me!
I think it would be reasonable to build out the pipeline abstraction as well as the pipeline formation of common datasets in this library. What do you think @josevalim and @wojtekmach?
Thanks @seanmor5 for the write-up! I am thinking this can also be part of Nx, since loading data is going to be an issue for almost all apps using Nx. But it is hard to say before we add streaming to devices (which is still on my plate).
Related to discussion in #11 and would resolve https://github.com/elixir-nx/axon/issues/25
I have been digging into input pipelines, specifically `tf.data` and `torch.utils.data.DataLoader`. I'm more familiar with `tf.data`, and it has a more intentionally functional pattern, so I'm mostly biased towards an API very similar to theirs. Here's a recent paper on `tf.data` as well.

The goal of an efficient input pipeline is to keep accelerators as busy as possible. There are a lot of places bottlenecks can happen (large IO operations, data transfer to GPU, slow input transformations, etc.), so it's important that any implementation be as performance-sensitive as possible. `tf.data` has some interesting benchmarks where they simulate an "infinitely fast" neural network to measure the absolute throughput of their API; I believe they achieve something like 13k images processed per second. It would be interesting to replicate some of these benchmarks, but I won't get ahead of myself.

Input pipelines can be characterized in 3 phases: Extract, Transform, Load. I'll briefly summarize each stage and its challenges.
Extract
This stage reads data from storage - think loading images from directories or streaming text from files. It's heavily IO-bound and slow, and because most practical datasets are massive, loading the entire dataset into memory is impractical.
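As a sketch of what lazy extraction looks like in Elixir today, `File.stream!` already gives us a stream that reads from disk only as it is consumed (the file name below is hypothetical):

```elixir
# Lazily stream examples from disk; no IO happens until the stream is consumed.
examples =
  File.stream!("train.csv")              # hypothetical file, read line by line
  |> Stream.drop(1)                      # skip the header row
  |> Stream.map(&String.split(&1, ","))  # parse each line into fields

# Only now does IO happen, and only as much as needed for 32 examples:
first_batch = Enum.take(examples, 32)
```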
Transform
This stage applies transformations and preprocessing to the input data. This could be anything from image augmentation to applying masks, padding, etc. Most of these operations are compute-intensive, and because the accelerator should stay busy doing the actual training, transformation work is most efficiently offloaded to the CPU.
Load
This stage actually loads data into accelerators. Transferring data from CPU to an accelerator can prove costly, but there are some tricks such as staging / prefetching input buffers to improve performance.
The main abstraction in `tf.data` is the `tf.data.Dataset`, which represents a stateless input pipeline. The `tf.data.Dataset` is analogous to an Elixir `Stream`. It can be transformed with functions such as `filter`, `flat_map`, `batch`, `map`, `reduce`, etc., and these transformations are fused into a graph that can then be statically optimized. The input pipeline also offers the ability to "prefetch" data (stage it for efficient transfer to accelerators), cache data so it's read from memory or faster storage later on, as well as "dynamic" optimizations that tune the parallelism and memory usage of the pipeline.

Based on the considerations above, I propose we create an input pipeline abstraction very similar to `tf.data`, based on Streams. Here are my initial thoughts.

First, we can define a struct that stores the actual input / label stream as well as metadata:
So we can capture shape / type information if necessary. `:supervised` is `true` in cases where labels are present and `false` when they are not.

We'll then have a number of "extract" methods that return new pipelines from a variety of formats. It should also be trivial to create new "extract" methods, but we should cover the most common cases and ensure we have them as optimized as possible:
We'll also have a number of transformations. I think we might be able to do some static optimizations and fusions of our own, and we can ensure each transformation is jitted by default.
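These could mirror Elixir's `Stream` combinators; a hypothetical sketch (function names are assumptions):

```elixir
# Hypothetical transformation pipeline; each step returns a new lazy dataset.
dataset
|> Scidata.Dataset.map(&normalize/1)           # per-example function, jitted by default
|> Scidata.Dataset.filter(&valid?/1)           # drop examples that fail a predicate
|> Scidata.Dataset.shuffle(buffer_size: 1024)  # shuffle within a bounded buffer
|> Scidata.Dataset.batch(32)                   # group examples into batches
```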
We'll also want some "performance" based functions, although I haven't really thought about how these can be most efficiently implemented:
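Something like the following, by analogy with `tf.data`'s `cache` and `prefetch` (names and options here are hypothetical):

```elixir
# Hypothetical performance knobs:
dataset
|> Scidata.Dataset.cache()                   # memoize results of earlier stages
|> Scidata.Dataset.prefetch(buffer_size: 2)  # stage batches ahead of the training step
```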
And we'll want ways to lazily iterate through the dataset, although this can most likely be done directly on the input and label streams. This kind of application IMO suits Elixir very well. I'm not really experienced with `GenStage`, `Broadway`, `Flow`, etc., so maybe they are useful here and somebody else can comment, or maybe they are irrelevant and I'll not bring them up again :)

I believe the responsibility of building out this API probably best fits in this library, unless we want to limit the purpose of Scidata to just focusing on datasets and move the pipeline logic elsewhere. Pending any feedback, I can start putting some basic things together in a PR and then begin working on integrating it with Axon.