JuliaAI / MLJ.jl

A Julia machine learning framework
https://juliaai.github.io/MLJ.jl/

Issue with Documentation/Example - DecisionTreeClassifier again... #156

Closed sdwfrost closed 5 years ago

sdwfrost commented 5 years ago

In trying to run the Iris/DecisionTree example, I get the following error in fit:

MethodError: no method matching build_tree(::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}, ::Array{Float64,2}, ::Float64, ::Int64, ::Int64, ::Int64, ::Float64)
Closest candidates are:
  build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27
  build_tree(!Matched::Array{T,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any, ::Any; rng) where {S, T} at /home/simon/.julia/dev/DecisionTree/src/classification/main.jl:83
  build_tree(!Matched::Array{T<:Float64,1}, ::Array{S,2}, ::Any, ::Any, ::Any, ::Any) where {S, T<:Float64} at /home/simon/.julia/dev/DecisionTree/src/regression/main.jl:27
  ...

Stacktrace:
 [1] fit(::DecisionTreeClassifier, ::Int64, ::DataFrame, ::CategoricalArray{String,1,UInt8,String,CategoricalString{UInt8},Union{}}) at /home/simon/.julia/dev/MLJModels/src/DecisionTree.jl:110
 [2] #fit!#3(::Array{Int64,1}, ::Int64, ::Bool, ::Function, ::Machine{DecisionTreeClassifier}) at /home/simon/.julia/dev/MLJ/src/machines.jl:131
 [3] (::getfield(StatsBase, Symbol("#kw##fit!")))(::NamedTuple{(:rows,),Tuple{Array{Int64,1}}}, ::typeof(fit!), ::Machine{DecisionTreeClassifier}) at ./none:0
 [4] top-level scope at In[6]:2

I've tried updating DecisionTree, MLJ and MLJModels to the latest versions, and I believe I'm looking at the right docs. I've tried using both RDatasets and the task interface.

Gist here
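The "Closest candidates" list shows that each build_tree method types its first (labels) argument as Array{T,1}. A CategoricalArray is an AbstractVector but not an Array, so no method matches. The following stdlib-only Julia sketch reproduces that kind of dispatch failure; build_tree_like is a hypothetical stand-in, not DecisionTree's function, and a UnitRange plays the role of the CategoricalArray.

```julia
# Hypothetical stand-in mirroring the shape of the candidate signatures:
# labels must be a concrete Vector (Array{T,1}), features a Matrix (Array{S,2}).
build_tree_like(labels::Array{T,1}, features::Array{S,2}) where {S,T} =
    (length(labels), size(features, 2))

features = rand(3, 4)   # 3 samples, 4 features
labels = 1:3            # an AbstractVector that is not an Array{T,1}

# Passing a non-Array vector raises a MethodError, as in the report above.
err = try
    build_tree_like(labels, features)
    nothing
catch e
    e
end
println(typeof(err))    # MethodError

# Converting to a plain Vector satisfies the signature.
println(build_tree_like(collect(labels), features))   # (3, 4)
```

This only illustrates the dispatch rule; it is not necessarily how the MLJModels wrapper should convert the target.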

davidbp commented 5 years ago

I had the same issue.

If you do y = CategoricalArray(y) before fitting, it should work. If you update to the latest MLJ version, you don't need to do that.

Here you can find an example I tested a couple of days ago:

https://github.com/davidbp/learn_julia/blob/master/MLJ/01_getting_started_classification.ipynb

sdwfrost commented 5 years ago

That's not it. The following:

using MLJ
using RDatasets
iris = dataset("datasets", "iris")
iris[:,5]

gives a 150-element CategoricalArray{String,1,UInt8}, which is exactly the same as what you get by doing y=CategoricalArray(y). I'm also on the latest MLJ.

ablaom commented 5 years ago

Can you please send the output of ‘]status -m’ for the environment you are working in?

sdwfrost commented 5 years ago

Thanks! I was actually looking for the command that gave the environment:

  [621f4979] AbstractFFTs v0.4.1
  [1520ce14] AbstractTrees v0.2.1
  [79e6a3ab] Adapt v0.4.2
  [7d9fca2a] Arpack v0.3.1
  [4fba245c] ArrayInterface v0.1.1
  [bf4720bc] AssetRegistry v0.1.0
  [c52e3926] Atom v0.8.7
  [67c07d97] Automa v0.8.0
  [13072b0f] AxisAlgorithms v1.0.0
  [39de3d68] AxisArrays v0.3.0
  [fbb218c0] BSON v0.2.3
  [aae01518] BandedMatrices v0.9.2
  [76274a88] Bijectors v0.3.0
  [9e28174c] BinDeps v0.8.10
  [b99e7846] BinaryProvider v0.5.4
  [37cfa864] BioCore v2.0.5
  [47718e42] BioGenerics v0.1.0
  [7e6ae17a] BioSequences v1.1.0
  [3c28c6f8] BioSymbols v3.1.0
  [a74b3585] Blosc v0.5.1
  [764a87c0] BoundaryValueDiffEq v2.2.3
  [e1450e63] BufferedStreams v1.0.0
  [631607c0] CMake v1.1.2
  [d5fb7624] CMakeWrapper v0.2.3
  [00ebfdb7] CSTParser v0.5.2
  [336ed68f] CSV v0.5.5
  [3895d2a7] CUDAapi v0.6.3
  [c5f51814] CUDAdrv v3.0.1
  [be33ccc6] CUDAnative v2.1.3
  [49dc2e85] Calculus v0.4.1
  [7057c7e9] Cassette v0.2.3
  [324d7699] CategoricalArrays v0.5.4
  [aaaa29a8] Clustering v0.13.1
  [53a63b46] CodeTools v0.6.4
  [da1fd8a2] CodeTracking v0.5.7
  [944b1d66] CodecZlib v0.5.2
  [19ecbf4d] Codecs v0.5.0
  [3da002f7] ColorTypes v0.8.0
  [5ae59095] Colors v0.9.5
  [861a8166] Combinatorics v0.7.0
  [bbf7d656] CommonSubexpressions v0.2.0
  [34da2185] Compat v2.1.0
  [8f4d0f93] Conda v1.3.0
  [d38c429a] Contour v0.5.1
  [a8cc5b0e] Crayons v4.0.0
  [3a865a2d] CuArrays v1.0.2
  [717857b8] DSP v0.5.2
  [d58978e5] Dagger v0.8.0
  [0fe7c1db] DataArrays v0.7.0
  [a93c6f00] DataFrames v0.18.3
  [864edb3b] DataStructures v0.15.0
  [e7dc6d0d] DataValues v0.4.7
  [7806a523] DecisionTree v0.8.1+ [`~/.julia/dev/DecisionTree`]
  [bcd4f6db] DelayDiffEq v5.3.0
  [2b5f629d] DiffEqBase v5.10.1
  [459566f4] DiffEqCallbacks v2.5.2
  [01453d9d] DiffEqDiffTools v0.10.1
  [5a0ffddc] DiffEqFinancial v2.1.0
  [c894b116] DiffEqJump v6.1.1
  [78ddff82] DiffEqMonteCarlo v0.14.0
  [77a26b50] DiffEqNoiseProcess v3.3.1
  [9fdde737] DiffEqOperators v3.5.0
  [055956cb] DiffEqPhysics v3.1.0
  [163ba53b] DiffResults v0.0.4
  [b552c78f] DiffRules v0.0.10
  [0c46a032] DifferentialEquations v6.4.0
  [c619ae07] DimensionalPlotRecipes v0.2.0
  [b4f34e82] Distances v0.8.0
  [31c24e10] Distributions v0.20.0
  [33d173f1] DocSeeker v0.2.0
  [ffbed154] DocStringExtensions v0.7.0
  [497a8b3b] DoubleFloats v0.9.1
  [fdbdab4c] ElasticArrays v0.4.0
  [2904ab23] ElasticPDMats v0.2.1
  [d4d017d3] ExponentialUtilities v1.4.0
  [8f5d6c58] EzXML v0.9.1
  [7a1cc6ca] FFTW v0.2.4
  [442a2c76] FastGaussQuadrature v0.3.3
  [a0c94c4b] FastaIO v0.5.0
  [5789e2e9] FileIO v1.0.6
  [1a297f60] FillArrays v0.6.3
  [53c48c17] FixedPointNumbers v0.5.3
  [587475ba] Flux v0.8.3
  [f6369f11] ForwardDiff v0.10.3
  [da1fdf0e] FreqTables v0.3.1
  [069b7b12] FunctionWrappers v1.0.0
  [de31a74c] FunctionalCollections v0.5.0
  [38e38edf] GLM v1.1.1
  [0c68f7d7] GPUArrays v0.7.1
  [28b8d3ca] GR v0.40.0
  [92fee26a] GZip v0.5.0
  [891a1506] GaussianProcesses v0.9.0
  [01680d73] GenericSVD v0.2.1
  [c145ed77] GenericSchur v0.2.3
  [4d00f742] GeometryTypes v0.7.5
  [c27321d9] Glob v1.2.0
  [f67ccb44] HDF5 v0.11.1
  [cd3eb016] HTTP v0.8.2
  [0862f596] HTTPClient v0.2.1
  [9fb69e20] Hiccup v0.2.2
  [d9be37ee] Homebrew v0.7.1
  [7073ff75] IJulia v1.18.1
  [1cb3b9ac] IndexableBitVectors v1.0.0
  [6deec6e2] IndexedTables v0.12.0
  [83e8ac13] IniFile v0.5.0
  [a98d9a8b] Interpolations v0.12.2
  [8197267c] IntervalSets v0.3.1
  [524e6230] IntervalTrees v1.0.0
  [c8e1da08] IterTools v1.1.1
  [42fd0dbc] IterativeSolvers v0.8.1
  [82899510] IteratorInterfaceExtensions v1.0.0
  [682c06a0] JSON v0.20.0
  [a93385a2] JuliaDB v0.12.0
  [aa1ae85d] JuliaInterpreter v0.6.0
  [e5e0dc1b] Juno v0.7.0
  [5ab0869b] KernelDensity v0.5.1
  [2d691ee1] LIBLINEAR v0.5.1
  [b1bec4e5] LIBSVM v0.3.1
  [929cbde3] LLVM v1.1.1
  [7c4cb9fa] LNR v0.2.0
  [b964fa9f] LaTeXStrings v1.0.3
  [50d2b5c4] Lazy v0.13.2
  [5078a376] LazyArrays v0.9.0
  [7f8f8fb0] LearnBase v0.2.2
  [b27032c2] LibCURL v0.5.1
  [522f3ed2] LibExpat v0.5.0
  [b13ce0c6] LibSndFile v2.0.0
  [6f1fad26] Libtask v0.3.0
  [2ec943e9] Libz v1.0.0
  [d3d80556] LineSearches v7.0.1
  [30fc2ffe] LossFunctions v0.5.1
  [c7f686f2] MCMCChains v0.3.10
  [add582a8] MLJ v0.2.3 [`~/.julia/dev/MLJ`]
  [a7f614a8] MLJBase v0.2.2
  [d491faf4] MLJModels v0.2.3 [`~/.julia/dev/MLJModels`]
  [1914dd2f] MacroTools v0.5.0
  [dbb5928d] MappedArrays v0.2.1
  [a3b82374] MatrixFactorizations v0.0.4
  [739be429] MbedTLS v0.6.8
  [442fdcdd] Measures v0.3.0
  [e89f7d12] Media v0.5.0
  [f9f48841] MemPool v0.2.0
  [e1d29d7a] Missings v0.4.1
  [78c3b35d] Mocking v0.5.7
  [46d2c3a1] MuladdMacro v0.2.1
  [f9640e96] MultiScaleArrays v1.4.0
  [d41bc354] NLSolversBase v7.3.1
  [2774e3e8] NLsolve v4.0.0
  [872c559c] NNlib v0.6.0
  [77ba4419] NaNMath v0.3.2
  [9bbee03b] NaiveBayes v0.4.0
  [86f7a689] NamedArrays v0.9.2
  [b8a86587] NearestNeighbors v0.4.3
  [4d1e1d77] Nullables v0.0.8
  [510215fc] Observables v0.2.3
  [6fe1bfb0] OffsetArrays v0.11.0
  [a15396b6] OnlineStats v0.23.0
  [925886fa] OnlineStatsBase v0.10.2
  [429524aa] Optim v0.18.1
  [bac558e1] OrderedCollections v1.1.0
  [1dea7af3] OrdinaryDiffEq v5.8.1
  [90014a1f] PDMats v0.9.7
  [d96e819e] Parameters v0.10.3
  [69de0a69] Parsers v0.3.5
  [06bb1623] PenaltyFunctions v0.1.2
  [fa939f87] Pidfile v1.1.0
  [ccf2f8ad] PlotThemes v0.3.0
  [995b91a9] PlotUtils v0.5.8
  [91a5bcdd] Plots v0.25.1
  [e409e4f3] PoissonRandom v0.4.0
  [f27b6e38] Polynomials v0.5.2
  [2dfb63ee] PooledArrays v0.5.2
  [85a6dd25] PositiveFactorizations v0.2.2
  [92933f4c] ProgressMeter v1.0.0
  [438e738f] PyCall v1.91.2
  [d330b81b] PyPlot v2.8.1
  [1fd47b50] QuadGK v2.0.4
  [be4d8f0f] Quadmath v0.4.0
  [df47a6cb] RData v0.6.0
  [ce6b1742] RDatasets v0.6.2
  [e6cf234a] RandomNumbers v1.3.0
  [b3c3ace0] RangeArrays v0.3.1
  [c84ed2f1] Ratios v0.3.1
  [3cdcf5f2] RecipesBase v0.6.0
  [731186ca] RecursiveArrayTools v0.20.0
  [f2c3362d] RecursiveFactorization v0.0.1
  [189a3867] Reexport v0.2.0
  [cbe49d4c] RemoteFiles v0.2.1
  [ae029012] Requires v0.5.2
  [ae5879a3] ResettableStacks v0.6.0
  [79098fc4] Rmath v0.5.0
  [f2b01f46] Roots v0.8.1
  [bd7594eb] SampledSignals v2.0.0
  [3646fa90] ScikitLearn v0.5.0
  [6e75b9c4] ScikitLearnBase v0.4.1
  [992d4aef] Showoff v0.2.1
  [b85f4697] SoftGlobalScope v1.0.10
  [a2af1166] SortingAlgorithms v0.3.1
  [276daf66] SpecialFunctions v0.7.2
  [90137ffa] StaticArrays v0.11.0
  [2913bbd2] StatsBase v0.30.0
  [4c63d2b9] StatsFuns v0.8.0
  [3eaba693] StatsModels v0.5.0
  [9672c7b4] SteadyStateDiffEq v1.4.0
  [789caeaf] StochasticDiffEq v6.4.0
  [88034a9c] StringDistances v0.3.2
  [09ab397b] StructArrays v0.3.4
  [c3572dad] Sundials v3.6.0
  [fd094767] Suppressor v0.1.1
  [7522ee7d] SweepOperator v0.2.0
  [3783bdb8] TableTraits v1.0.0
  [382cd787] TableTraitsUtils v0.4.0
  [bd369af6] Tables v0.2.5
  [e0df1984] TextParse v0.9.1
  [f269a46b] TimeZones v0.9.1
  [a759f4b9] TimerOutputs v0.5.0
  [0796e94c] Tokenize v0.5.3
  [37b6cedf] Traceur v0.3.0
  [9f7883ad] Tracker v0.2.2
  [3bb67fe8] TranscodingStreams v0.9.4
  [a2a6695c] TreeViews v0.3.0
  [fce5fe82] Turing v0.6.17
  [7200193e] Twiddle v1.1.0
  [30578b45] URIParser v0.4.0
  [1986cc42] Unitful v0.15.0
  [81def892] VersionParsing v1.1.3
  [ea10d353] WeakRefStrings v0.5.8
  [0f1e0344] WebIO v0.8.4
  [104b5d7c] WebSockets v1.5.2
  [cc8bc4a8] Widgets v0.6.1
  [c17dfb99] WinRPM v0.4.2
  [efce3f68] WoodburyMatrices v0.4.1
  [009559a3] XGBoost v0.3.1
  [ddb6d928] YAML v0.3.2
  [c2297ded] ZMQ v1.0.0
  [a5390f91] ZipFile v0.8.3
  [2a0f44e3] Base64 
  [ade2ca70] Dates 
  [8bb1440f] DelimitedFiles 
  [8ba89e20] Distributed 
  [7b1f6079] FileWatching 
  [9fa8497b] Future 
  [b77e0a4c] InteractiveUtils 
  [76f85450] LibGit2 
  [8f399da3] Libdl 
  [37e2e46d] LinearAlgebra 
  [56ddb016] Logging 
  [d6f4376e] Markdown 
  [a63ad114] Mmap 
  [44cfe95a] Pkg 
  [de0858da] Printf 
  [9abbd945] Profile 
  [3fa0cd96] REPL 
  [9a3f8284] Random 
  [ea8e919c] SHA 
  [9e88b42a] Serialization 
  [1a1011a3] SharedArrays 
  [6462fe0b] Sockets 
  [2f01184e] SparseArrays 
  [10745b16] Statistics 
  [4607b0f0] SuiteSparse 
  [8dfed614] Test 
  [cf7118a7] UUIDs 
  [4ec0a83e] Unicode 

Here's also an extract from Manifest.toml

[[CategoricalArrays]]
deps = ["Compat", "Future", "JSON", "Missings", "Printf", "Reexport"]
git-tree-sha1 = "26601961df6afacdd16d67c1eec6cfe75e5ae9ab"
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.5.4"

[[DecisionTree]]
deps = ["DelimitedFiles", "Distributed", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics", "Test"]
path = "/home/simon/.julia/dev/DecisionTree"
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
version = "0.8.1+"

[[MLJ]]
deps = ["CSV", "CategoricalArrays", "Dates", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJModels", "Pkg", "ProgressMeter", "Random", "RecipesBase", "RemoteFiles", "Statistics", "StatsBase", "Tables"]
path = "/home/simon/.julia/dev/MLJ"
uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
version = "0.2.3"

[[MLJBase]]
deps = ["CSV", "CategoricalArrays", "Distributions", "InteractiveUtils", "Random", "SparseArrays", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "ef43664c9488e1de4a3b1bd15aa65b9d3dfc4d99"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.2.2"

[[MLJModels]]
deps = ["CategoricalArrays", "Distances", "Distributions", "LIBSVM", "LinearAlgebra", "MLJBase", "Pkg", "Random", "Requires"]
path = "/home/simon/.julia/dev/MLJModels"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.2.3"

pfarndt commented 5 years ago

I ran into the same issue.

ablaom commented 5 years ago

Okay, the thing is, you can't presently call predict or a loss function on tasks. You have to give it actual data, as in the following example:

julia> using MLJ

julia> iris = load_iris()
julia> @load DecisionTreeClassifier
julia> tree_model = DecisionTreeClassifier(max_depth=2);
julia> tree = machine(tree_model, iris);
julia> X, y = iris();
julia> train, test = partition(eachindex(y), 0.7);  # assumed 70/30 split of row indices
julia> fit!(tree, rows=train)
[ Info: Training Machine{DecisionTreeClassifier} @ 1…64.
Machine{DecisionTreeClassifier} @ 1…64

julia> yhat = predict(tree, X[test,:]);

julia> misclassification_rate(yhat, y[test])
0.022222222222222223
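The misclassification rate is the fraction of test rows whose predicted label differs from the true label; the value printed above equals 1/45, which is consistent with one error on a 45-row test set. A minimal Julia sketch of that formula, assuming deterministic (point) predictions; misclassification_rate_sketch is a hypothetical helper, not MLJ's misclassification_rate.

```julia
# Hypothetical helper: fraction of predictions that disagree with the truth.
# Assumes point predictions of the same length as the true labels.
function misclassification_rate_sketch(yhat::AbstractVector, y::AbstractVector)
    length(yhat) == length(y) || throw(DimensionMismatch("lengths differ"))
    return count(yhat .!= y) / length(y)
end

# 45 test labels with exactly one wrong prediction
y_true = fill("virginica", 45)
y_pred = copy(y_true)
y_pred[1] = "versicolor"

println(misclassification_rate_sketch(y_pred, y_true))   # 1/45 ≈ 0.0222
```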

You can, however, just evaluate the model without unpacking the task data. That is, you could reduce the above to:

julia> using MLJ

julia> iris = load_iris()
julia> @load DecisionTreeClassifier
julia> tree_model = DecisionTreeClassifier(max_depth=2);
julia> tree = machine(tree_model, iris)
julia> evaluate!(tree,
                 measure=misclassification_rate,
                 resampling=Holdout(fraction_train=0.7,
                                  shuffle=true))
┌ Info: Evaluating using a holdout set. 
│ fraction_train=0.7 
│ shuffle=true 
│ measure=MLJ.misclassification_rate 
│ operation=StatsBase.predict 
└ Resampling from all rows. 
0.044444444444444446
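With 150 iris rows and fraction_train=0.7, a holdout set of 45 rows is expected, and 0.0444… is consistent with 2 errors on it (2/45). The Julia sketch below shows a shuffled holdout split in that spirit; holdout_rows is a hypothetical helper, not MLJ's Holdout implementation, and its rounding and RNG handling are assumptions.

```julia
using Random

# Hypothetical sketch of a holdout split over row indices 1:n.
# With shuffle=true the rows are randomly permuted before splitting.
function holdout_rows(rng::AbstractRNG, n::Integer, fraction_train::Real;
                      shuffle::Bool=true)
    0 < fraction_train < 1 || throw(ArgumentError("fraction_train must be in (0, 1)"))
    rows = shuffle ? randperm(rng, n) : collect(1:n)
    ntrain = round(Int, fraction_train * n)   # assumed rounding rule
    return rows[1:ntrain], rows[ntrain+1:end]
end

train, test = holdout_rows(MersenneTwister(42), 150, 0.7)
println((length(train), length(test)))   # (105, 45)
```

A fixed seed (MersenneTwister(42) here) makes the split reproducible; the measure is then computed only on the test rows.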

Hope that helps. Thanks for raising an issue!

#158 (Allow calling predict on tasks?)

ablaom commented 5 years ago

Re-opening. @sdwfrost, I had only checked your newer gist file, iris_v2. I'm not yet able to reproduce your error on iris.ipynb, but will look into it. Sorry, but could you send me the output of status -p in the pkg manager, or send me your Project.toml?

sdwfrost commented 5 years ago

Dear @ablaom, the output of status -p is above. Thanks for looking into this...

ablaom commented 5 years ago

Okay, I have been able to reproduce this. This seems to be related to #159.

@sdwfrost Can you try pinning CategoricalArrays to v0.5.2 and see if your problem is resolved? (In the pkg manager: pin CategoricalArrays@0.5.2) This worked for me.

sdwfrost commented 5 years ago

@ablaom pinning CategoricalArrays to v0.5.2 worked for me! I can certainly work with this for now.

ablaom commented 5 years ago

A CategoricalArrays compatibility requirement has been added for now, pending resolution of #159.