JuliaAI / MLJ.jl

A Julia machine learning framework

predict fails in MLJ 0.4 with DecisionTreeClassifier #226

Closed juliohm closed 5 years ago

juliohm commented 5 years ago

Describe the bug

MLJBase.predict(model, theta, X)

stopped working.

To Reproduce

using MLJ
using MLJBase

@load DecisionTreeClassifier()

function g()
  X = rand(100,2)
  y = rand(Int, 100)
  m = DecisionTreeClassifier()
  θ, _, __ = MLJBase.fit(m, 0, X, y)
  MLJBase.predict(m, θ, X)
end

g()


Expected behavior

The code was working fine in MLJ < 0.4.

Additional context

Julia v1.2

tlienart commented 5 years ago

OK, the error I see when reproducing this is:

julia> MLJBase.predict(m, θ, X)
ERROR: classes must have CategoricalValue or CategoricalString type.
 [1] error(::String) at ./error.jl:33
 [2] UnivariateFinite(::Array{Int64,1}, ::Array{Float64,1}) at /Users/tlienart/.julia/dev/MLJBase/src/distributions.jl:111
 [3] (::MLJModels.DecisionTree_.var"##4#5"{DecisionTreeClassifier,Array{Int64,1},Array{Float64,2}})(::Int64) at ./none:0
 [4] iterate at ./generator.jl:47 [inlined]
 [5] collect at ./array.jl:620 [inlined]
 [6] predict(::DecisionTreeClassifier, ::Tuple{DecisionTree.Node{Float64,Int64},Array{Int64,1}}, ::Array{Float64,2}) at /Users/tlienart/.julia/dev/MLJModels/src/DecisionTree.jl:173
 [7] top-level scope at REPL[7]:1

Wrapping y in categorical fixes the issue:

X = rand(100, 2)
y = categorical(rand(Int, 100))

With this change, the call works.

juliohm commented 5 years ago

Thank you. Is the fit/predict API described formally somewhere? I would like to be able to handle all possible cases as I expand the framework. Right now, some design decisions, such as requiring MLJBase.table and MLJBase.categorical, are not very clear to me. I don't understand why we need to call categorical explicitly: isn't Int a Finite scitype? Can't we proceed without converting to categorical manually?

juliohm commented 5 years ago

Having to handle the classification case manually doesn't seem ideal. As you can see, the generic fit/predict code I was using didn't need to know whether the task was classification or regression. Now I need to add an "if" statement to handle this special case:

Can we leverage multiple-dispatch to provide a higher-level abstraction? What am I missing?

ablaom commented 5 years ago


No, in MLJ you cannot represent categorical variables with integers. See the bottom of Getting Started and the table at the end of the ScientificTypes.jl README for details.

After inspecting your data with schema, you can use the coerce method to convert a vector or table column to the type MLJ expects. See the ScientificTypes.jl README for details. These methods are re-exported by MLJ and MLJBase. There is an open issue to improve this workflow further.
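For illustration, here is a minimal sketch of the schema/coerce workflow, assuming a recent MLJ release that re-exports `scitype`, `coerce`, `schema` and `Multiclass` (printed types differ between versions, so no outputs are shown). The data `y` and `X` are made up:

```julia
using MLJ  # assumed to re-export scitype, coerce, schema and Multiclass

y = [1, 2, 3, 1, 2]          # plain integers have scitype Count, not Finite
@show scitype(y)

# Declare the integers to be unordered categorical labels; coerce returns
# a categorical vector whose element scitype is Multiclass{3}.
y = coerce(y, Multiclass)
@show scitype(y)

# For a table (here a NamedTuple of vectors), schema lists the scitype of
# every column, which tells you which columns still need coercing.
X = (height = [1.2, 3.4, 5.6, 1.1, 2.0], group = [1, 2, 1, 2, 1])
schema(X)
```

The coerced `y` is what a classifier such as DecisionTreeClassifier expects as its target.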

The decision not to allow an integer representation of categorical data is deliberate. A typical problem with integer-coded categoricals is this: if there are many levels (e.g., supermarket items, medical diagnoses, words in a language), one trains on a subsample only to discover that the evaluation set contains levels not seen in training, and the evaluation crashes. Some frameworks (scikit-learn, for example) ignore this issue completely, which causes no end of anguish to newcomers. To avoid such problems, one must either pass around metadata on the number of levels, use a suitable wrapper (like CategoricalVector from CategoricalArrays), or use a dedicated categorical value type (such as CategoricalValue from the same package) where every element points to a pool of all the levels (and their labels). We have chosen the last option, but since indexing into a CategoricalVector returns a CategoricalValue, one can also present the wrapped objects to MLJ. By the way, the R frameworks all use dedicated categorical element types.
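To make the unseen-levels problem concrete, here is a minimal pure-Julia sketch. The types `Pool` and `PooledValue` and the helpers `label`, `alllevels` and `pooled` are hypothetical, loosely modeled on the pool-of-levels idea described above; they are not the CategoricalArrays.jl implementation:

```julia
# Hypothetical minimal pooled categorical representation, for illustration
# only: it is NOT how CategoricalArrays.jl is implemented internally.
struct Pool
    levels::Vector{String}          # every level (label) in the full data
end

struct PooledValue
    pool::Pool                      # shared reference to the full pool
    code::Int                       # position of this value in pool.levels
end

label(v::PooledValue) = v.pool.levels[v.code]
alllevels(v::PooledValue) = v.pool.levels

# Encode a vector of labels so that all elements share one Pool.
function pooled(labels::Vector{String})
    pool = Pool(unique(labels))
    index = Dict(l => i for (i, l) in enumerate(pool.levels))
    return [PooledValue(pool, index[l]) for l in labels]
end

data = pooled(["apple", "bread", "milk", "apple", "eggs"])
train = data[1:3]                   # this subsample never contains "eggs"
test = data[4:5]

# Integer encoding: a model can only count the classes it sees in train.
n_seen = length(unique(v.code for v in train))

# Pooled encoding: any training element still reports every level.
n_levels = length(alllevels(train[1]))
```

Here `n_seen` is 3 while the "eggs" element of `test` has code 4, so a model sized from the integer codes in `train` has no slot for it; every pooled value, by contrast, carries all four levels.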

Another reason for using CategoricalVectors has to do with interpretability and usability. It's very convenient to keep the original handle on the levels, and these handles propagate through operations such as one-hot encoding. In the physical sciences this is usually not much of an issue, but elsewhere, where there are many levels, it matters more.
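The point about labels surviving one-hot encoding can be sketched the same way. The `onehot` function below is hypothetical (it is not an MLJ transformer); it takes the full list of levels, as a pooled categorical would supply, and names each output column after its level:

```julia
# Hypothetical one-hot encoder, for illustration only. The full list of
# levels is supplied, so the output columns keep the original labels and a
# subsample is encoded with the same columns as the full data.
function onehot(values::Vector{String}, levels::Vector{String})
    colnames = ["is_" * l for l in levels]       # labels survive encoding
    M = zeros(Int, length(values), length(levels))
    for (i, v) in enumerate(values)
        j = findfirst(==(v), levels)
        j === nothing && error("unknown level: $v")
        M[i, j] = 1
    end
    return colnames, M
end

all_levels = ["apple", "bread", "milk", "eggs"]
colnames, M = onehot(["milk", "apple"], all_levels)
```

Even though this subsample contains only two levels, the result has one named column per level, so downstream output stays interpretable.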

Note that most libraries for ingesting tabular data (e.g., CSV) can automatically convert String data to CategoricalVectors as a matter of course.

> Thank you. Is the fit/predict API described formally somewhere?

The definitive document is https://alan-turing-institute.github.io/MLJ.jl/stable/adding_models_for_general_use/.

Regarding this code:

function g()
  X = rand(100,2)
  y = rand(Int, 100)
  m = DecisionTreeClassifier()
  θ, _, __ = MLJBase.fit(m, 0, X, y)
  MLJBase.predict(m, θ, X)
end

@tlienart has correctly identified the cause of the crash, but keep in mind that using a matrix for X is at your own risk. A table is required:

julia> info("DecisionTreeClassifier")
OrderedCollections.LittleDict{Symbol,Any,Array{Symbol,1},Array{Any,1}} with 13 entries:
  :name             => "DecisionTreeClassifier"
  :package_name     => "DecisionTree"
  :package_url      => "https://github.com/bensadeghi/DecisionTree.jl"
  :package_license  => "unkown"
  :load_path        => "MLJModels.DecisionTree_.DecisionTreeClassifier"
  :is_wrapper       => false
  :is_pure_julia    => true
  :package_uuid     => "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
  :supports_weights => false
  :is_supervised    => true
  :is_probabilistic => true
  :input_scitype    => Table{#s13} where #s13<:(AbstractArray{#s12,1} where #s12<:C…
  :target_scitype   => AbstractArray{#s24,1} where #s24<:Finite
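One way to satisfy the table requirement with matrix data is to wrap the matrix with `table` (the `MLJBase.table` mentioned earlier in this thread). A minimal sketch, assuming a recent MLJ that re-exports `table`, `scitype`, `Table` and `Continuous`:

```julia
using MLJ

A = rand(100, 2)            # a raw matrix of Float64 values
X = MLJ.table(A)            # wrap it as a Tables.jl-compatible table

# X now has a table scitype whose columns are all Continuous, rather than
# the matrix scitype of A.
@show scitype(A)
@show scitype(X)
```

The wrapped `X` can then be passed to `MLJBase.fit` in place of the raw matrix, together with a categorical target.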

ablaom commented 5 years ago

Closing as no MLJ bug identified. Feel free to re-open if one is identified, or open a relevant design discussion instead.

juliohm commented 5 years ago

I just want to say that I really appreciate this brainstorming process. Today I fixed a long-standing issue in GeoStats.jl: every spatial data object now satisfies the Tables.jl API. I've updated my learn/perform functions to use this API, so I can pass my spatial data directly to MLJ.jl methods and it is treated as a table. 💯
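For readers wondering what "satisfies the Tables.jl API" involves, here is a minimal sketch of a column-access table. The `SpatialData` type is hypothetical and is not the actual GeoStats.jl implementation; only `Tables.istable`, `Tables.columnaccess`, `Tables.columns` and `Tables.columntable` come from Tables.jl:

```julia
using Tables

# Hypothetical spatial data container (not GeoStats.jl's actual type):
# coordinates plus one attribute column.
struct SpatialData
    x::Vector{Float64}
    y::Vector{Float64}
    temperature::Vector{Float64}
end

# Declare that SpatialData is a table accessed column by column.
Tables.istable(::Type{SpatialData}) = true
Tables.columnaccess(::Type{SpatialData}) = true

# A NamedTuple of equal-length vectors already satisfies the Tables.jl
# column interface, so it can be returned directly.
Tables.columns(d::SpatialData) = (x = d.x, y = d.y, temperature = d.temperature)

d = SpatialData([0.0, 1.0], [0.0, 2.0], [15.5, 17.2])
ct = Tables.columntable(d)   # any Tables.jl consumer can now read d
```

Once a type implements these methods, generic table consumers (including MLJ's table handling) can read it without a conversion step.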