pat-alt closed this issue 1 year ago
I think this is in line with how `Flux` works. For example, I can call the `<:Chain` in the `fitresult` on a matrix:
julia> mach.fitresult[1](Xmat)
3×150 Matrix{Float64}:
0.280818 0.327379 0.293104 0.520361 0.334125 0.289304 0.52877 … 0.343907 0.274209 0.551496 0.273814 0.286353 0.54921 0.551846
0.339427 0.338713 0.32827 0.322471 0.338675 0.336802 0.319926 0.331148 0.342926 0.309555 0.335115 0.339573 0.310251 0.309612
0.379755 0.333908 0.378626 0.157168 0.327199 0.373894 0.151303 0.324945 0.382865 0.138949 0.391072 0.374074 0.140539 0.138542
So I'd expect the `predict` method to work as well, no?
julia> MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = X
julia> predict(mach, Xmat)
150-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, String, UInt8, Float64}:
UnivariateFinite{Multiclass{3}}(setosa=>0.281, versicolor=>0.339, virginica=>0.38)
UnivariateFinite{Multiclass{3}}(setosa=>0.327, versicolor=>0.339, virginica=>0.334)
UnivariateFinite{Multiclass{3}}(setosa=>0.293, versicolor=>0.328, virginica=>0.379)
UnivariateFinite{Multiclass{3}}(setosa=>0.52, versicolor=>0.322, virginica=>0.157)
UnivariateFinite{Multiclass{3}}(setosa=>0.334, versicolor=>0.339, virginica=>0.327)
UnivariateFinite{Multiclass{3}}(setosa=>0.289, versicolor=>0.337, virginica=>0.374)
This was the intended design as most models in MLJ expect tabular data. But there's no reason we cannot support both tables and matrix data. I'll comment on your PR shortly, thank you.
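To illustrate how one entry point could accept both tables and matrices, here is a minimal, self-contained sketch of the dispatch pattern. It does not use MLJFlux: `reformat_sketch`, `kind`, `TableTag` and `MatrixTag` are hypothetical names, and a table is assumed to be a `NamedTuple` of equal-length column vectors. The real `MLJFlux.reformat` dispatches on a type in its second argument, as the one-liner above shows, so treat this only as an analogy.

```julia
# Hypothetical tags standing in for the type that the real method dispatches on.
struct TableTag end
struct MatrixTag end

# Classify the input. Assumption: a "table" is a NamedTuple of column vectors.
kind(::NamedTuple) = TableTag()
kind(::AbstractMatrix) = MatrixTag()

# Single entry point that forwards to a method chosen by the kind of input.
reformat_sketch(X) = reformat_sketch(X, kind(X))

# Tables: stack the columns side by side, then transpose so that
# each column of the result is one observation (features × observations).
reformat_sketch(X, ::TableTag) = permutedims(reduce(hcat, collect(values(X))))

# Matrices: assumed to already be features × observations, so pass through.
reformat_sketch(X, ::MatrixTag) = X

table = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
mat = reformat_sketch(table)          # 2×3 matrix, one row per feature
same = reformat_sketch(mat) === mat   # true: the matrix is returned unchanged
```

The pass-through method looks redundant, but it is what lets a caller hand over a matrix without first wrapping it in a table.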
Not sure if this is by design, but calling predict on a matrix throws an error.
Using the example from the README:
The following throws an error related to formatting:
Will create a PR for this issue in a moment that simply adds the following to `core.jl`:
Looks redundant but works 😄