FluxML / MLJFlux.jl

Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
http://fluxml.ai/MLJFlux.jl/
MIT License
An interface to the Flux deep learning models for the [MLJ](https://github.com/alan-turing-institute/MLJ.jl) machine learning framework

Stable

| Branch | Julia | CPU CI | GPU CI | Coverage |
| --- | --- | --- | --- | --- |
| master | v1 | Continuous Integration (CPU) | Continuous Integration (GPU) | Code Coverage |
| dev | v1 | Continuous Integration (CPU) | Continuous Integration (GPU) | Code Coverage |

Code Snippet

using MLJ, MLJFlux, RDatasets, Plots

Grab some data and split into features and target:

iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), rng=123);
X = Float32.(X);      # To optimise for GPUs

Load model code and instantiate an MLJFlux model:

NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4)),
    batch_size=8,
    epochs=50,
    acceleration=CUDALibs()  # for training on a GPU
)
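
To make the `hidden=(5,4)` setting concrete, the following sketch in plain Julia (no Flux required) works out the layer widths and parameter count it implies for the iris data. It assumes 4 input features, 3 classes, and that the builder appends one dense output layer, with biases, after the hidden layers. The helper names `layer_sizes` and `dense_param_count` are hypothetical and not part of MLJFlux.

```julia
# Hypothetical helpers for illustration; not part of MLJFlux or Flux.

# Layer widths of a fully connected network: input, hidden layers, output.
layer_sizes(n_in, hidden, n_out) = [n_in, hidden..., n_out]

# Total number of weights plus biases across consecutive dense layers.
function dense_param_count(sizes)
    return sum(sizes[i] * sizes[i+1] + sizes[i+1] for i in 1:length(sizes)-1)
end

sizes = layer_sizes(4, (5, 4), 3)   # iris: 4 features, 3 species (assumed)
n_params = dense_param_count(sizes)
println(sizes)      # [4, 5, 4, 3]
println(n_params)   # 64
```

A network this small trains quickly on a CPU, so the GPU setting above matters mainly for larger models.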

Wrap in "iteration controls":

stop_conditions = [
    Step(1),            # Apply controls every epoch
    NumberLimit(1000),  # Don't train for more than 1000 steps
    Patience(4),        # Stop after 4 consecutive deteriorations in validation loss
    NumberSinceBest(5), # Or if the best loss occurred 5 iterations ago
    TimeLimit(30/60),   # Or if 30 minutes have passed (the limit is given in hours)
]

validation_losses = []
train_losses = []
callbacks = [
    WithLossDo(loss->push!(validation_losses, loss)),
    WithTrainingLossesDo(losses->push!(train_losses, losses[end])),
]

iterated_model = IteratedModel(
    model=clf,
    resampling=Holdout(fraction_train=0.5), # loss and stopping are based on out-of-sample data
    measures=log_loss,
    controls=vcat(stop_conditions, callbacks),
);
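
The `Patience` and `NumberSinceBest` criteria listed above can trigger at different times on the same loss curve. The sketch below is a simplified plain-Julia model of the two rules, assuming `Patience(n)` fires after `n` consecutive increases in the loss and `NumberSinceBest(n)` fires once the best loss so far lies `n` observations back. The functions `patience_stop` and `number_since_best_stop` are hypothetical helpers, not the library implementation, and the loss values are made up.

```julia
# Hypothetical helpers; a simplified model of the two criteria, not library code.

# Index at which `n` consecutive increases in the loss have been seen, else `nothing`.
function patience_stop(losses, n)
    streak = 0
    for i in 2:length(losses)
        streak = losses[i] > losses[i-1] ? streak + 1 : 0
        streak == n && return i
    end
    return nothing
end

# Index at which the best loss so far lies `n` observations back, else `nothing`.
function number_since_best_stop(losses, n)
    best, since = Inf, 0
    for (i, loss) in enumerate(losses)
        if loss < best
            best, since = loss, 0
        else
            since += 1
        end
        since == n && return i
    end
    return nothing
end

# Made-up losses: an oscillating plateau that never rises 4 times in a row.
losses = [1.0, 0.8, 0.6, 0.65, 0.62, 0.66, 0.63, 0.67, 0.64]
println(patience_stop(losses, 4))           # nothing
println(number_since_best_stop(losses, 5))  # 8
```

On an oscillating plateau like this one, the patience rule never fires while the since-best rule does, which is one reason to combine several criteria.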

Train the wrapped model:

julia> mach = machine(iterated_model, X, y)
julia> fit!(mach)

[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`. 
[ Info: final loss: 0.1284184007796247
[ Info: final training loss: 0.055630706
[ Info: Stop triggered by NumberSinceBest(5) stopping criterion. 
[ Info: Total of 811 iterations. 

Inspect results:

julia> plot(train_losses, label="Training Loss")
julia> plot!(validation_losses, label="Validation Loss", linewidth=2, size=(800,400))
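
Besides plotting, the recorded losses can be inspected numerically. The sketch below uses Julia's built-in `findmin` to locate the lowest recorded loss and the control step at which it occurred. The vector `recorded_losses` and its values are made up for illustration; in practice one would pass the `validation_losses` vector filled by the callback.

```julia
# Illustrative values only (made up); substitute the vector filled by the callback.
recorded_losses = [0.92, 0.41, 0.18, 0.13, 0.15, 0.14]

# `findmin` returns the minimum value together with its (1-based) index.
best_loss, best_iteration = findmin(recorded_losses)
println("lowest validation loss $best_loss at control step $best_iteration")
```

With `Step(1)`, controls are applied every epoch, so the control step corresponds to the epoch number.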