FluxML / MLJFlux.jl

Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
http://fluxml.ai/MLJFlux.jl/
MIT License

Calling predict on matrix throws error #218

Closed pat-alt closed 1 year ago

pat-alt commented 1 year ago

Not sure if this is by design, but calling predict on a matrix throws an error.

Using the example from the README:

# Example from README:
using MLJ
import RDatasets
iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), colname -> true, rng=123);
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
clf = NeuralNetworkClassifier()
mach = machine(clf, X, y)
fit!(mach)

Calling predict on the feature matrix (transposed so that each column is one observation) then throws a MethodError from reformat:

julia> Xmat = permutedims(MLJ.matrix(X))
4×150 Matrix{Float64}:
 6.7  5.7  7.2  4.4  5.6  6.5  4.4  6.1  …  6.1  6.7  5.0  7.6  6.3  5.1  5.0
 3.3  2.8  3.0  2.9  2.5  3.0  3.0  2.9     2.8  2.5  3.5  3.0  2.5  3.8  3.6
 5.7  4.1  5.8  1.4  3.9  5.2  1.3  4.7     4.0  5.8  1.3  6.6  5.0  1.6  1.4
 2.1  1.3  1.6  0.2  1.1  2.0  0.2  1.4     1.3  1.8  0.3  2.1  1.9  0.2  0.2

julia> predict(mach, Xmat)
ERROR: MethodError

Error stack:

  (1) top-level scope
      ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
          50     quote
          51         try
        ❯ 52             $(ex)
          53         finally
          54             $task_local_state()...

  (2) top-level scope
      REPL[43]:1

      Skipped 1 frames in MLJBase

  (5) reformat(X::Matrix)
      ~/.julia/packages/MLJFlux/6XVNm/src/core.jl:145
          143 nrows(y::AbstractVector) = length(y)
          144
        ❯ 145 reformat(X) = reformat(X, scitype(X))
          146
          147 # ---------------------------------

MethodError:

  No method matching `reformat` with arguments types:
  ::Matrix, ::DataType

  Alternative candidates:
    reformat(::Any)
    reformat(::Any, ::Type{<:Table})
    reformat(::Any, ::Type{<:GrayImage})

I'll open a PR for this issue shortly that simply adds the following method to core.jl:

reformat(X, ::Type{<:AbstractMatrix}) = X

Looks redundant but works 😄
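To make the failure mode concrete, here is a minimal, dependency-free sketch of the dispatch pattern visible in the stack trace: a one-argument method forwards to a two-argument method selected by the scientific type of the data. All names below (`Continuous`, `Table`, `toy_scitype`, `toy_reformat`) are hypothetical stand-ins written for illustration. They are not the MLJFlux or ScientificTypes implementations, and the table branch is only an assumed approximation of what a table-to-matrix conversion might do.

```julia
# Hypothetical stand-ins for the scientific types involved (not the real ones).
struct Continuous end
abstract type Table end

# Hypothetical stand-in for `scitype`: maps data to a scientific type.
toy_scitype(::AbstractMatrix{<:Real}) = AbstractMatrix{Continuous}
toy_scitype(::NamedTuple) = Table

# Same shape as `reformat(X) = reformat(X, scitype(X))` from the stack trace.
toy_reformat(X) = toy_reformat(X, toy_scitype(X))

# Tables are converted to a features × observations matrix (assumed behaviour).
toy_reformat(X, ::Type{<:Table}) = permutedims(hcat(values(X)...))

Xtable = (a = [1.0, 2.0], b = [3.0, 4.0])
Xmat = toy_reformat(Xtable)   # 2×2 matrix, one column per observation

# Without a matrix method, dispatch fails just as in the issue.
err = try
    toy_reformat(Xmat)
catch e
    e
end

# The proposed fix: matrices are already in the right layout, so pass them through.
toy_reformat(X, ::Type{<:AbstractMatrix}) = X

Xsame = toy_reformat(Xmat)    # now returns the matrix unchanged
```

In this sketch the added method only looks redundant: the one-argument entry point always forwards to a two-argument method, so a pass-through has to exist for every scientific type that should be accepted as-is.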

pat-alt commented 1 year ago

I think this is in line with how Flux works. For example, I can call the <:Chain in the fitresult on a matrix:

julia> mach.fitresult[1](Xmat)
3×150 Matrix{Float64}:
 0.280818  0.327379  0.293104  0.520361  0.334125  0.289304  0.52877   …  0.343907  0.274209  0.551496  0.273814  0.286353  0.54921   0.551846
 0.339427  0.338713  0.32827   0.322471  0.338675  0.336802  0.319926     0.331148  0.342926  0.309555  0.335115  0.339573  0.310251  0.309612
 0.379755  0.333908  0.378626  0.157168  0.327199  0.373894  0.151303     0.324945  0.382865  0.138949  0.391072  0.374074  0.140539  0.138542
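For readers less familiar with the layout convention at play, the following dependency-free sketch mimics a dense layer followed by a column-wise softmax applied to a features × observations matrix. The weights `W`, the bias `b`, and the helpers `colsoftmax` and `dense_softmax` are hypothetical stand-ins, not the fitted chain from the machine above. Only the two input rows are taken from the iris values shown earlier.

```julia
# Hypothetical weights for a 4-input, 3-output dense layer (not the fitted model).
W = [ 0.1  0.2  0.3  0.4;
      0.0 -0.1  0.2  0.1;
     -0.2  0.1  0.0  0.3]
b = zeros(3)

# Numerically stable softmax applied to each column independently.
function colsoftmax(Z)
    E = exp.(Z .- maximum(Z; dims=1))
    return E ./ sum(E; dims=1)
end

# Each column of X is treated as one observation.
dense_softmax(X) = colsoftmax(W * X .+ b)

# Table-like layout: 2 observations × 4 features.
Xrows = [6.7 3.3 5.7 2.1;
         5.7 2.8 4.1 1.3]

# Transpose to 4 features × 2 observations before calling the layer.
Xcols = permutedims(Xrows)

P = dense_softmax(Xcols)   # 3×2 matrix: one column of class probabilities per observation
```

Each column of `P` sums to one, which matches the shape of the 3×150 output above: one probability vector per observation.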

So I'd expect the predict method to work as well, no? With the proposed method defined, it does:

julia> MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = X

julia> predict(mach, Xmat)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt8, Float64}:
 UnivariateFinite{Multiclass{3}}(setosa=>0.281, versicolor=>0.339, virginica=>0.38)
 UnivariateFinite{Multiclass{3}}(setosa=>0.327, versicolor=>0.339, virginica=>0.334)
 UnivariateFinite{Multiclass{3}}(setosa=>0.293, versicolor=>0.328, virginica=>0.379)
 UnivariateFinite{Multiclass{3}}(setosa=>0.52, versicolor=>0.322, virginica=>0.157)
 UnivariateFinite{Multiclass{3}}(setosa=>0.334, versicolor=>0.339, virginica=>0.327)
 UnivariateFinite{Multiclass{3}}(setosa=>0.289, versicolor=>0.337, virginica=>0.374)

ablaom commented 1 year ago

This was the intended design as most models in MLJ expect tabular data. But there's no reason we cannot support both tables and matrix data. I'll comment on your PR shortly, thank you.