ablaom opened 3 years ago
A related discussion: https://github.com/FluxML/MLJFlux.jl/issues/97 .
@lorenzoh
Huh, I recall at least writing a draft reply to this issue; it seems like it got lost, my bad.
All in all, I think this proposal is doable and useful; however, there may not be a clean way to implement it without some additional interfaces in MLDataPattern.jl.
Just to give some context, DataLoaders.jl does not have a `DataLoader` struct; instead, the `DataLoader` function returns a composition of `eachobsparallel`, a parallel observation iterator, and `batchviewcollated`, which (lazily) turns a data container of observations into a data container of collated batches. See the distinction between Data Views and Data Iterators in the MLDataPattern.jl docs.
Now, in order to slice a data loader (i.e. an `eachobsparallel(batchviewcollated(data, ...))`), we would have to allow mapping the indexing operation (which would probably be done lazily using `datasubset`) through the data iterator `eachobsparallel`, meaning roughly:
```julia
Base.getindex(dataiter::GetObsParallel, idxs) =
    GetObsParallel(datasubset(dataiter.data, idxs), dataiter.useprimary)
Base.getindex(dataiter::BufferGetObsParallel, idxs) =
    BufferGetObsParallel(datasubset(dataiter.data, idxs), dataiter.useprimary)
```
Here a problem arises: since there can be multiple data iterator implementations, one would have to implement the method for each; having an abstract type `AbstractDataIterator` or something and using Setfield.jl would help here. The other question is whether the indexing would be done on the batch view or on the underlying observations. Semantically, the latter makes more sense, I guess, but would be harder to implement cleanly.
@darsnack is this the best way to do it? Did I miss something?
EDIT: just saw that LearnBase.jl actually has abstract types `AbstractBatchView` and `DataIterator`, which could be used to implement this generically, though I'm not sure if they're used consistently. DataLoaders.jl would have to add those as supertypes, but that shouldn't be a problem.
That sounds great. Hope this works out.
I have nothing to add other than it should probably be:

```julia
MLDataPattern.datasubset(dataiter::GetObsParallel, idxs) =
    GetObsParallel(datasubset(dataiter.data, idxs), dataiter.useprimary)
MLDataPattern.datasubset(dataiter::BufferGetObsParallel, idxs) =
    BufferGetObsParallel(datasubset(dataiter.data, idxs), dataiter.useprimary)
```
In MLDataPattern.jl, we distinguish between subsetting a selection of indices and accessing them. Using `getindex` here seems like it would conflate the two.
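To illustrate that distinction with a toy sketch (not MLDataPattern.jl's actual implementation — `LazySubset`, `toysubset`, and `toygetobs` are made-up names standing in for `DataSubset`, `datasubset`, and `getobs`):

```julia
# Toy illustration of "subsetting" vs "accessing".
# All names here are hypothetical stand-ins for the MLDataPattern.jl API.
struct LazySubset{T}
    data::T
    idxs::Vector{Int}
end

# "Subsetting": returns a lazy wrapper; no observations are materialized.
toysubset(data, idxs) = LazySubset(data, collect(idxs))

# "Accessing": actually loads the requested observations.
toygetobs(s::LazySubset) = s.data[s.idxs]

data = [10, 20, 30, 40, 50]
sub = toysubset(data, 2:4)  # lazy: `data` is not copied
toygetobs(sub)              # eager: returns [20, 30, 40]
```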
Maybe we should add `[Buffer]GetObsParallel <: DataIterator` and `BatchViewCollated <: AbstractBatchView`, and then implement it as:

```julia
MLDataPattern.datasubset(dataiter::DataIterator, idxs) =
    setdata(dataiter, datasubset(getdata(dataiter), idxs))
```
Though that would require the `getdata` and `setdata` lenses, and this would have to be a change to MLDataPattern.jl itself. Do you think this would be a useful addition, Kyle?
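A rough, self-contained sketch of how those lenses could enable a single generic method (all names here are hypothetical; this is not code from MLDataPattern.jl):

```julia
# Hypothetical sketch: one generic `datasubset` method covering every
# iterator type, assuming each wrapper implements `getdata`/`setdata`.
abstract type DataIterator end

struct EachObs{T} <: DataIterator
    data::T
end

getdata(iter::EachObs) = iter.data
setdata(iter::EachObs, data) = EachObs(data)

# Stand-in for a lazy subset (MLDataPattern's `datasubset` in practice).
datasubset(data::AbstractVector, idxs) = view(data, idxs)

# The single generic method: rewrap the subsetted data in the same
# iterator type, instead of one method per concrete iterator.
datasubset(iter::DataIterator, idxs) =
    setdata(iter, datasubset(getdata(iter), idxs))

iter = EachObs([1, 2, 3, 4])
sub = datasubset(iter, 1:2)
getdata(sub)  # a view of [1, 2]
```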
Also, Anthony, would the `datasubset` approach be useful for you? I'm not sure if MLJ already has a dep on MLDataPattern.jl.
I am not sure we want to do that for every iterator. In fact, I am now doubting my previous suggestions. Almost always, a data iterator is the last piece in the data loading pipeline, so a transformation like what's being described here is safe. But it does break away from MLDataPattern.jl. Basically, under the pattern introduced by MLDataPattern.jl, something like
```julia
dl = DataLoader(...)
x = datasubset(dl, 1:5)
```
means "load a piece of data (a batch) using `DataLoader`, then give me the first 5 samples of that batch." This interpretation extends beyond `datasubset` or `DataLoader`. You could apply `split` or `kfold` on top of `eachbatch` and expect to split the batch using the same tools you used to split the original data. This simple and consistent pattern is one of the selling points of MLDataPattern.jl and partly why it is so flexible. I think we need to think more carefully about whether we want to deviate from that.
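That semantics can be illustrated with a toy sketch (hypothetical names; `toy_eachbatch` is a made-up stand-in for `eachbatch`): subsetting what the iterator *yields* slices within a batch, not the underlying dataset.

```julia
# Toy sketch: under the consistent MLDataPattern-style semantics,
# an operation applied to the iterator's output acts on the batch itself.
toy_eachbatch(data, bs) =
    (data[i:min(i + bs - 1, end)] for i in 1:bs:length(data))

data = collect(1:10)
batches = collect(toy_eachbatch(data, 8))
batch = first(batches)  # the first batch: [1, 2, ..., 8]
batch[1:5]              # "first 5 samples of that batch": [1, 2, 3, 4, 5]
```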
Personally, I have encountered the same problem as what @ablaom is describing for MLJ in my own code. I want to write a top level script that looks like
```julia
data = # ...
train, test = split(data, 0.8)
bs = 64
trainloader, testloader = eachbatch(train, bs), eachbatch(test, bs)

# some code

# calling my custom training loop
my_train!(..., trainloader, ...)
```
Now this pattern restricts `my_train!` to only iterating `trainloader`. Any kind of loop that needs to do something to the underlying data being iterated by `eachbatch` can't receive `trainloader` as written above. It needs to receive `train`. So then you get the following pattern:
```julia
data = # ...
train, test = split(data, 0.8)
bs = 64

# some code

# add a bs keyword
my_train!(..., train, ...; bs = bs)
```
Since this is my user code, I just modify `my_train!` to call `eachbatch` after it does its special manipulation.
But this presents a problem for packages that expose training loops. FastAI.jl just takes on the MLDataPattern.jl dependency and calls `eachbatch` for the user. But this may not work for all packages, and it doesn't seem like it would for MLJ.
@ablaom one option is to make this distinction clear in MLJ. What I mean is that the user passes in some `data`, which is some (possibly manipulated) data without the loader (i.e. `DataLoader`). Separately, they pass in an iterator constructor like `data -> DataLoader(data; bs = bs, ...)`. Then the MLJ package code can do additional things to the data, like subsetting it into multiple folds, and call the iterator constructor on each subset when it runs the training for each fold.
If that isn't an acceptable solution, then we need to introduce a different API into MLDataPattern that says "do X to the data wrapped by the iterator" instead of "do X to the output of the iterator."
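The first option (data plus an iterator constructor) might look roughly like this toy sketch; `evaluate_folds` and everything in it are hypothetical names, with `sum` standing in for a real training/evaluation loop:

```julia
# Hypothetical sketch of the "pass data + iterator constructor" pattern.
function evaluate_folds(data, make_iterator, folds)
    results = Int[]
    for idxs in folds
        subset = data[idxs]             # package code resamples the raw data...
        loader = make_iterator(subset)  # ...then builds the loader per fold
        push!(results, sum(loader))     # stand-in for a training/eval loop
    end
    return results
end

data = collect(1:6)
folds = [1:3, 4:6]
# the user supplies `data` and a constructor such as
# `subset -> DataLoader(subset; bs = bs)`; the identity works for the sketch
evaluate_folds(data, subset -> subset, folds)  # returns [6, 15]
```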
As you say, data iterators are usually applied only once, and last, since they don't change *what* the observations are, just *how* they are loaded. Considering that, I think it would be useful to treat data iterators as data containers where all container-specific functions are mapped through to the wrapped data container.
For example:
```julia
# load some data container
data = ...

# transform the observations
data = data |> shuffleobs |> batchviewcollated

# create an iterator
dataiter = eachobs(data)

# now transformations of the observations should be passed through
# to the container `eachobs` wraps:
dataiter = mapobs(gpu, dataiter)  # equivalent to `eachobs(mapobs(gpu, dataiter.data))`
```
@darsnack @lorenzoh Am really appreciating this discussion! Will try to add some comments early next week. Thanks again.
Okay, I've had a chance to educate myself a little better about MLDataPattern and DataLoaders. It may make more sense to integrate MLJ with MLDataPattern first, after all. This will take some more thought on my part, and I have to find the time somewhere...
Thanks for this discussion!
I should like to see enhancements in the MLJ ecosystem that allow models to work with out-of-memory data, and wonder if DataLoaders might be a good tool here. The main issue, as far as I can tell, is around observation resampling, which is the basis of performance evaluation, and by corollary, hyper-parameter optimization - meta-algorithms that can be applied to all current MLJ models.
So, I wondered if it would be possible for this package to implement `getindex(::DataLoader, indxs)`, for `indxs isa Union{Colon,Integer,AbstractVector{<:Integer}}`, returning an object with the same interface. This could be a new `SubDataLoader` object, but in any case it would be important for the original `eltype` to be knowable (assuming `eltype` is implemented for the original object, or you add it as a type parameter). Since the `DataLoader` type already requires the implementation of the "random access" `getobs` method, this looks quite doable to me.

I realize that for large datasets (the main use case for DataLoaders) resampling is often a simple holdout. However, because `Holdout` is implemented as a special case of more general resampling strategies (`CV`, etc.), it would be rather messy to add `DataLoader` support for just that case without the slicing feature.

Would there be any sympathy among current developers for such an enhancement? Perhaps there is an alternative solution to my issue?
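A minimal sketch of what such a `SubDataLoader` could look like. All names and fields here are made up for illustration; the real DataLoaders.jl `DataLoader` is structured differently (it returns a composition of iterators, as discussed above):

```julia
# Hypothetical sketch: slicing a loader returns a `SubToyLoader` with the
# same iteration interface and a knowable `eltype`.
struct ToyLoader{T}
    data::T
end

struct SubToyLoader{T}
    parent::ToyLoader{T}
    idxs::Vector{Int}
end

Base.getindex(dl::ToyLoader, indxs::AbstractVector{<:Integer}) =
    SubToyLoader(dl, collect(indxs))

# Both iterate observations; the sub-loader restricts to its indices.
Base.length(dl::ToyLoader) = length(dl.data)
Base.iterate(dl::ToyLoader, state = 1) =
    state > length(dl.data) ? nothing : (dl.data[state], state + 1)

Base.length(s::SubToyLoader) = length(s.idxs)
Base.iterate(s::SubToyLoader, state = 1) =
    state > length(s.idxs) ? nothing : (s.parent.data[s.idxs[state]], state + 1)

# the original eltype stays knowable through the type parameter
Base.eltype(::Type{SubToyLoader{T}}) where {T} = eltype(T)

dl = ToyLoader([10, 20, 30, 40])
collect(dl[[2, 4]])  # [20, 40]
```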
BTW, I don't really see how the suggestion in the docs, to apply observation resampling before data is wrapped in a `DataLoader`, could really work effectively in the MLJ context, as the idea is that resampling should remain completely automated. (It also seems from the documentation that this requires bringing the data into memory...?) But maybe I'm missing something there.