FluxML / FastAI.jl

Repository of best practices for deep learning in Julia, inspired by fastai
https://fluxml.ai/FastAI.jl
MIT License

FasterAI: a roadmap to user-friendly, high-level interfaces #148

Closed lorenzoh closed 3 years ago

lorenzoh commented 3 years ago

FasterAI is the working name for a new high-level interface for FastAI.jl, with the goal of making it easier for beginner users to get started.

Motivation

The current documentation examples are reasonably short and do a good job of showcasing the mid-level APIs. However, they can be daunting for beginners and could be made much shorter with only a few convenience interfaces.

using FastAI
path = datasetpath("imagenette2-160")
data = loadfolderdata(
    path,
    filterfn=isimagefile,
    loadfn=(loadfile, parentname))
classes = unique(eachobs(data[2]))
method = BlockMethod(
    (Image{2}(), Label(classes)),
    (
        ProjectiveTransforms((128, 128), augmentations=augs_projection()),
        ImagePreprocessing(),
        OneHot()
    )
)
learner = methodlearner(method, data, Models.xresnet18(), callbacks...)
fitonecycle!(learner, 10)
xs, ys = makebatch(method, data, 1:8)
ypreds = learner.model(xs)
plotpredictions(method, xs, ypreds, ys)

The logic can be reduced to the following operations, each of which constitutes a step in the basic workflow. A new high-level interface (codenamed "FasterAI") should make each operation a one-liner.

  1. Dataset: downloading and loading a dataset
  2. Learning method: creating a learning method
  3. Learner: creating a Learner
  4. Training: training the Learner
  5. Visualization: visualizing results after training

The existing abstractions are well suited for building high-level, user-friendly interfaces on top of them. The above example, and all the learning methods in the docs, could then be written in 5 lines:

data, classes = ImageClfFolders(labelfn=parentname)(datasetpath("imagenette2-160"))
method = ImageClassification((128, 128), classes)
learner = methodlearner(method, data, Models.xresnet18(), callbacks...)
fitonecycle!(learner, 10)
plotpredictions(method, learner)

Importantly, a good deal of customization is retained: every line can be replaced by the corresponding parts of the original example, offering full customizability without affecting the other lines. If you want to change the dataset you're using, only change the first line; if you want different training hyperparameters, change line 4; and so on.

It is important to note that there are no changes required to the existing APIs, so FasterAI will be easy to implement while not breaking existing functionality.

Ideas

Following are ideas for improving each of the above steps.

Dataset

For some dataset formats like the basic image classification, the loadfolderdata helper already makes it possible to write one-liners for loading a dataset:

data = loadfolderdata(datasetpath("imagenette2-160"), filterfn=isimagefile, loadfn=(loadfile, parentname))

For others, this isn't always possible. Consider the segmentation and multi-label classification examples from the quickstart docs:

# multi-label classification
df = loadfile(joinpath(path, "train.csv"))
data = (
    mapobs(f -> loadfile(joinpath(path, "train", f)), df.fname),  # images
    mapobs(labelstr -> split(labelstr, ' '), df.labels),          # labels
)

# segmentation
classes = readlines(open(joinpath(path, "codes.txt")))
data = (
    loadfolderdata(joinpath(path, "images"), filterfn=isimagefile, loadfn=loadfile),
    loadfolderdata(joinpath(path, "labels"), filterfn=isimagefile, loadfn=f -> loadmask(f, classes)),
)

And even for the single-label classification case, you have to manually call unique(eachobs(data[2])). These steps could be intimidating for users unfamiliar with the data container API.
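To make the tuple-of-containers convention concrete, here is a minimal, dependency-free sketch (everything below is mocked with plain vectors; the real containers from `loadfolderdata` are lazy, but observation-wise they behave the same):

```julia
# Mock data container: a tuple of (inputs, targets), as produced by
# loadfolderdata with loadfn=(loadfile, parentname).
data = (
    ["img1.jpg", "img2.jpg", "img3.jpg", "img4.jpg"],  # inputs
    ["cat", "dog", "cat", "dog"],                      # targets
)

# eachobs iterates observations; for a plain vector those are just its
# elements, so extracting the class set reduces to:
classes = unique(data[2])
# → ["cat", "dog"]
```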

One solution would be dataset recipes that encapsulate the logic for loading a data container stored in a common format, along with its metadata. The recipes could still allow some customization through arguments while keeping a consistent API across datasets. A recipe is just a configuration object that can be called, for example on a path, returning a data container and metadata:

data, metadata = DatasetRecipe(args...; kwargs...)(path)

data, classes = ImageClfFolders(labelfn=parentname)(datasetpath("imagenette2-160"))
data, classes = SegmentationFolders(
    labelfile=p"codes.txt",
    imagefolder="images",
    labelfolder="labels")(datasetpath("camvid_tiny"))

Additionally, the recipe approach makes it easier to document the data loading process: each recipe's docstring describes the expected format and the configuration options. Recipes also make it possible to give user-friendly error messages when the file format differs from what is expected.
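As a rough sketch of the abstraction, a recipe can be a plain configuration struct made callable on a path. The struct name mirrors the proposal above, but the fields and body here are illustrative stand-ins, not the real implementation:

```julia
# Hypothetical recipe: a callable configuration object returning
# (data, metadata). A real version would walk `path` lazily; here we
# fake two files to show the calling convention.
struct ImageClfFolders
    labelfn::Function
end
ImageClfFolders(; labelfn) = ImageClfFolders(labelfn)

parentname(p) = basename(dirname(p))  # label = name of containing folder

function (recipe::ImageClfFolders)(path)
    files = [joinpath(path, "cat", "1.jpg"), joinpath(path, "dog", "2.jpg")]
    labels = map(recipe.labelfn, files)
    data = (files, labels)
    metadata = unique(labels)         # e.g. the class names
    return data, metadata
end

data, classes = ImageClfFolders(labelfn = parentname)("imagenette2-160")
# classes == ["cat", "dog"]
```

Because the recipe is an ordinary struct, its docstring is the natural place to describe the expected on-disk format, and the call can throw a descriptive error as soon as the layout doesn't match.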

Looking further, this dataset recipe abstraction could also improve other parts of the workflow.

API examples (see also code above for recipe examples)

Learning method

The first thing to introduce here is a collection of function wrappers for creating common learning methods. As with the dataset recipes above, this allows documenting them, throwing helpful errors, and constraining the number of mistakes possible when adapting an existing example.

ImageClassificationSingle(sz::NTuple{N}, classes) where N = BlockMethod(
    (Image{N}(), Label(classes)),
    (ProjectiveTransforms(sz), ImagePreprocessing(), OneHot())
)

ImageClassificationMulti(sz::NTuple{N}, classes) where N = BlockMethod(
    (Image{N}(), LabelMulti(classes)),
    (ProjectiveTransforms(sz), ImagePreprocessing(), OneHot())
)

ImageSegmentation(sz::NTuple{N}, classes) where N = BlockMethod(
    (Image{N}(), Mask{N}(classes)),
    (ProjectiveTransforms(sz), ImagePreprocessing(), OneHot())
)

Additionally, there could be a function for discovering these learning methods given just block types:

julia> learningmethods(Tuple{Image, Mask})
[ImageSegmentation]

julia> learningmethods(Tuple{Image, Any})
[ImageSegmentation, ImageClassificationSingle, ImageClassificationMulti]
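One possible way to implement this discovery is a registry keyed by block types, matched against the query with the subtype relation, so that `Any` acts as a wildcard. The block types and registry below are mocked stand-ins, not the real FastAI.jl blocks (results are sorted here only to make the order deterministic):

```julia
# Mock block types standing in for FastAI.jl's blocks.
abstract type Block end
struct Image <: Block end
struct Mask <: Block end
struct Label <: Block end

# Registry: block types => names of matching learning-method wrappers.
const METHOD_REGISTRY = Dict{Type,Vector{Symbol}}(
    Tuple{Image,Mask}  => [:ImageSegmentation],
    Tuple{Image,Label} => [:ImageClassificationSingle, :ImageClassificationMulti],
)

# Return all registered methods whose block types are subtypes of the
# query; tuple covariance makes `Any` in the query act as a wildcard.
learningmethods(query::Type) =
    sort(unique(reduce(vcat,
        [ms for (T, ms) in METHOD_REGISTRY if T <: query]; init = Symbol[])))

learningmethods(Tuple{Image,Mask})  # → [:ImageSegmentation]
learningmethods(Tuple{Image,Any})   # → all three registered methods
```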

Learner

methodlearner is already a good high-level entry point: given a learning method, a data container, and a model, it takes care of constructing a ready-to-train Learner.

Training

fitonecycle!, finetune! and lrfind are high-level enough one-liners for many training use cases. Maybe add fitflatcos! for fastai feature parity.

Visualization

plotpredictions already exists and compares predictions vs. targets for supervised tasks. For usability, plotpredictions, plotoutputs, and plotsamples need convenience methods that take a Learner directly:

plotpredictions(method, learner)
plotoutputs(method, learner)
plotsamples(method, learner)
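These overloads could simply pull a batch from the learner's data and delegate to the existing lower-level form. A minimal mocked sketch of that delegation pattern (the Learner, makebatch, and plotting function here are all stand-ins for the real FastAI.jl ones):

```julia
# Stand-in for FastAI.jl's Learner: holds a model and a data container.
struct MockLearner
    model
    data
end

# Stand-in for makebatch: index into the (inputs, targets) container.
makebatch(method, data, idxs) = (data[1][idxs], data[2][idxs])

# Existing low-level form (mocked): just report what it would plot.
plotpredictions(method, xs, ypreds, ys) = (n = length(xs),)

# Proposed convenience overload: derive the batch and the predictions
# from the learner, then delegate to the low-level form.
function plotpredictions(method, learner::MockLearner; idxs = 1:4)
    xs, ys = makebatch(method, learner.data, idxs)
    model = learner.model
    ypreds = model.(xs)
    return plotpredictions(method, xs, ypreds, ys)
end

learner = MockLearner(x -> x, (collect(1:8), collect(1:8)))
plotpredictions(nothing, learner)  # → (n = 4,)
```

The same pattern would cover plotoutputs and plotsamples, so all three stay thin wrappers over the method-based functions that already exist.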