Closed sdwfrost closed 5 years ago
I had the same issue.
If you call y=CategoricalArray(y) before fit, it should work. If you update to the latest MLJ version, you don't need to do that.
Here you can find an example I tested a couple of days ago:
https://github.com/davidbp/learn_julia/blob/master/MLJ/01_getting_started_classification.ipynb
That's not it. The following:
using MLJ
using RDatasets
iris = dataset("datasets", "iris")
iris[:,5]
gives a 150-element CategoricalArray{String,1,UInt8}, which is exactly the same as what you get by doing y=CategoricalArray(y). I'm also on the latest MLJ.
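A minimal sketch for confirming the point above, assuming RDatasets and CategoricalArrays are both installed in the active environment. The variable name y is only illustrative, and the exact printed type will differ between CategoricalArrays versions:

```julia
using RDatasets, CategoricalArrays

iris = dataset("datasets", "iris")
y = iris[:, 5]                     # the Species column

# Inspect the container type and the class labels it carries
println(typeof(y))
println(y isa CategoricalArray)    # true if no conversion is needed
println(levels(y))
```

If the second line prints true, wrapping y in CategoricalArray(y) should not change anything, which is consistent with the conversion not being the cause here.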
Can you please send the output of ‘]status -m’ for the environment you are working in?
Thanks! I was actually looking for the command that gave the environment:
[621f4979] AbstractFFTs v0.4.1
[1520ce14] AbstractTrees v0.2.1
[79e6a3ab] Adapt v0.4.2
[7d9fca2a] Arpack v0.3.1
[4fba245c] ArrayInterface v0.1.1
[bf4720bc] AssetRegistry v0.1.0
[c52e3926] Atom v0.8.7
[67c07d97] Automa v0.8.0
[13072b0f] AxisAlgorithms v1.0.0
[39de3d68] AxisArrays v0.3.0
[fbb218c0] BSON v0.2.3
[aae01518] BandedMatrices v0.9.2
[76274a88] Bijectors v0.3.0
[9e28174c] BinDeps v0.8.10
[b99e7846] BinaryProvider v0.5.4
[37cfa864] BioCore v2.0.5
[47718e42] BioGenerics v0.1.0
[7e6ae17a] BioSequences v1.1.0
[3c28c6f8] BioSymbols v3.1.0
[a74b3585] Blosc v0.5.1
[764a87c0] BoundaryValueDiffEq v2.2.3
[e1450e63] BufferedStreams v1.0.0
[631607c0] CMake v1.1.2
[d5fb7624] CMakeWrapper v0.2.3
[00ebfdb7] CSTParser v0.5.2
[336ed68f] CSV v0.5.5
[3895d2a7] CUDAapi v0.6.3
[c5f51814] CUDAdrv v3.0.1
[be33ccc6] CUDAnative v2.1.3
[49dc2e85] Calculus v0.4.1
[7057c7e9] Cassette v0.2.3
[324d7699] CategoricalArrays v0.5.4
[aaaa29a8] Clustering v0.13.1
[53a63b46] CodeTools v0.6.4
[da1fd8a2] CodeTracking v0.5.7
[944b1d66] CodecZlib v0.5.2
[19ecbf4d] Codecs v0.5.0
[3da002f7] ColorTypes v0.8.0
[5ae59095] Colors v0.9.5
[861a8166] Combinatorics v0.7.0
[bbf7d656] CommonSubexpressions v0.2.0
[34da2185] Compat v2.1.0
[8f4d0f93] Conda v1.3.0
[d38c429a] Contour v0.5.1
[a8cc5b0e] Crayons v4.0.0
[3a865a2d] CuArrays v1.0.2
[717857b8] DSP v0.5.2
[d58978e5] Dagger v0.8.0
[0fe7c1db] DataArrays v0.7.0
[a93c6f00] DataFrames v0.18.3
[864edb3b] DataStructures v0.15.0
[e7dc6d0d] DataValues v0.4.7
[7806a523] DecisionTree v0.8.1+ [`~/.julia/dev/DecisionTree`]
[bcd4f6db] DelayDiffEq v5.3.0
[2b5f629d] DiffEqBase v5.10.1
[459566f4] DiffEqCallbacks v2.5.2
[01453d9d] DiffEqDiffTools v0.10.1
[5a0ffddc] DiffEqFinancial v2.1.0
[c894b116] DiffEqJump v6.1.1
[78ddff82] DiffEqMonteCarlo v0.14.0
[77a26b50] DiffEqNoiseProcess v3.3.1
[9fdde737] DiffEqOperators v3.5.0
[055956cb] DiffEqPhysics v3.1.0
[163ba53b] DiffResults v0.0.4
[b552c78f] DiffRules v0.0.10
[0c46a032] DifferentialEquations v6.4.0
[c619ae07] DimensionalPlotRecipes v0.2.0
[b4f34e82] Distances v0.8.0
[31c24e10] Distributions v0.20.0
[33d173f1] DocSeeker v0.2.0
[ffbed154] DocStringExtensions v0.7.0
[497a8b3b] DoubleFloats v0.9.1
[fdbdab4c] ElasticArrays v0.4.0
[2904ab23] ElasticPDMats v0.2.1
[d4d017d3] ExponentialUtilities v1.4.0
[8f5d6c58] EzXML v0.9.1
[7a1cc6ca] FFTW v0.2.4
[442a2c76] FastGaussQuadrature v0.3.3
[a0c94c4b] FastaIO v0.5.0
[5789e2e9] FileIO v1.0.6
[1a297f60] FillArrays v0.6.3
[53c48c17] FixedPointNumbers v0.5.3
[587475ba] Flux v0.8.3
[f6369f11] ForwardDiff v0.10.3
[da1fdf0e] FreqTables v0.3.1
[069b7b12] FunctionWrappers v1.0.0
[de31a74c] FunctionalCollections v0.5.0
[38e38edf] GLM v1.1.1
[0c68f7d7] GPUArrays v0.7.1
[28b8d3ca] GR v0.40.0
[92fee26a] GZip v0.5.0
[891a1506] GaussianProcesses v0.9.0
[01680d73] GenericSVD v0.2.1
[c145ed77] GenericSchur v0.2.3
[4d00f742] GeometryTypes v0.7.5
[c27321d9] Glob v1.2.0
[f67ccb44] HDF5 v0.11.1
[cd3eb016] HTTP v0.8.2
[0862f596] HTTPClient v0.2.1
[9fb69e20] Hiccup v0.2.2
[d9be37ee] Homebrew v0.7.1
[7073ff75] IJulia v1.18.1
[1cb3b9ac] IndexableBitVectors v1.0.0
[6deec6e2] IndexedTables v0.12.0
[83e8ac13] IniFile v0.5.0
[a98d9a8b] Interpolations v0.12.2
[8197267c] IntervalSets v0.3.1
[524e6230] IntervalTrees v1.0.0
[c8e1da08] IterTools v1.1.1
[42fd0dbc] IterativeSolvers v0.8.1
[82899510] IteratorInterfaceExtensions v1.0.0
[682c06a0] JSON v0.20.0
[a93385a2] JuliaDB v0.12.0
[aa1ae85d] JuliaInterpreter v0.6.0
[e5e0dc1b] Juno v0.7.0
[5ab0869b] KernelDensity v0.5.1
[2d691ee1] LIBLINEAR v0.5.1
[b1bec4e5] LIBSVM v0.3.1
[929cbde3] LLVM v1.1.1
[7c4cb9fa] LNR v0.2.0
[b964fa9f] LaTeXStrings v1.0.3
[50d2b5c4] Lazy v0.13.2
[5078a376] LazyArrays v0.9.0
[7f8f8fb0] LearnBase v0.2.2
[b27032c2] LibCURL v0.5.1
[522f3ed2] LibExpat v0.5.0
[b13ce0c6] LibSndFile v2.0.0
[6f1fad26] Libtask v0.3.0
[2ec943e9] Libz v1.0.0
[d3d80556] LineSearches v7.0.1
[30fc2ffe] LossFunctions v0.5.1
[c7f686f2] MCMCChains v0.3.10
[add582a8] MLJ v0.2.3 [`~/.julia/dev/MLJ`]
[a7f614a8] MLJBase v0.2.2
[d491faf4] MLJModels v0.2.3 [`~/.julia/dev/MLJModels`]
[1914dd2f] MacroTools v0.5.0
[dbb5928d] MappedArrays v0.2.1
[a3b82374] MatrixFactorizations v0.0.4
[739be429] MbedTLS v0.6.8
[442fdcdd] Measures v0.3.0
[e89f7d12] Media v0.5.0
[f9f48841] MemPool v0.2.0
[e1d29d7a] Missings v0.4.1
[78c3b35d] Mocking v0.5.7
[46d2c3a1] MuladdMacro v0.2.1
[f9640e96] MultiScaleArrays v1.4.0
[d41bc354] NLSolversBase v7.3.1
[2774e3e8] NLsolve v4.0.0
[872c559c] NNlib v0.6.0
[77ba4419] NaNMath v0.3.2
[9bbee03b] NaiveBayes v0.4.0
[86f7a689] NamedArrays v0.9.2
[b8a86587] NearestNeighbors v0.4.3
[4d1e1d77] Nullables v0.0.8
[510215fc] Observables v0.2.3
[6fe1bfb0] OffsetArrays v0.11.0
[a15396b6] OnlineStats v0.23.0
[925886fa] OnlineStatsBase v0.10.2
[429524aa] Optim v0.18.1
[bac558e1] OrderedCollections v1.1.0
[1dea7af3] OrdinaryDiffEq v5.8.1
[90014a1f] PDMats v0.9.7
[d96e819e] Parameters v0.10.3
[69de0a69] Parsers v0.3.5
[06bb1623] PenaltyFunctions v0.1.2
[fa939f87] Pidfile v1.1.0
[ccf2f8ad] PlotThemes v0.3.0
[995b91a9] PlotUtils v0.5.8
[91a5bcdd] Plots v0.25.1
[e409e4f3] PoissonRandom v0.4.0
[f27b6e38] Polynomials v0.5.2
[2dfb63ee] PooledArrays v0.5.2
[85a6dd25] PositiveFactorizations v0.2.2
[92933f4c] ProgressMeter v1.0.0
[438e738f] PyCall v1.91.2
[d330b81b] PyPlot v2.8.1
[1fd47b50] QuadGK v2.0.4
[be4d8f0f] Quadmath v0.4.0
[df47a6cb] RData v0.6.0
[ce6b1742] RDatasets v0.6.2
[e6cf234a] RandomNumbers v1.3.0
[b3c3ace0] RangeArrays v0.3.1
[c84ed2f1] Ratios v0.3.1
[3cdcf5f2] RecipesBase v0.6.0
[731186ca] RecursiveArrayTools v0.20.0
[f2c3362d] RecursiveFactorization v0.0.1
[189a3867] Reexport v0.2.0
[cbe49d4c] RemoteFiles v0.2.1
[ae029012] Requires v0.5.2
[ae5879a3] ResettableStacks v0.6.0
[79098fc4] Rmath v0.5.0
[f2b01f46] Roots v0.8.1
[bd7594eb] SampledSignals v2.0.0
[3646fa90] ScikitLearn v0.5.0
[6e75b9c4] ScikitLearnBase v0.4.1
[992d4aef] Showoff v0.2.1
[b85f4697] SoftGlobalScope v1.0.10
[a2af1166] SortingAlgorithms v0.3.1
[276daf66] SpecialFunctions v0.7.2
[90137ffa] StaticArrays v0.11.0
[2913bbd2] StatsBase v0.30.0
[4c63d2b9] StatsFuns v0.8.0
[3eaba693] StatsModels v0.5.0
[9672c7b4] SteadyStateDiffEq v1.4.0
[789caeaf] StochasticDiffEq v6.4.0
[88034a9c] StringDistances v0.3.2
[09ab397b] StructArrays v0.3.4
[c3572dad] Sundials v3.6.0
[fd094767] Suppressor v0.1.1
[7522ee7d] SweepOperator v0.2.0
[3783bdb8] TableTraits v1.0.0
[382cd787] TableTraitsUtils v0.4.0
[bd369af6] Tables v0.2.5
[e0df1984] TextParse v0.9.1
[f269a46b] TimeZones v0.9.1
[a759f4b9] TimerOutputs v0.5.0
[0796e94c] Tokenize v0.5.3
[37b6cedf] Traceur v0.3.0
[9f7883ad] Tracker v0.2.2
[3bb67fe8] TranscodingStreams v0.9.4
[a2a6695c] TreeViews v0.3.0
[fce5fe82] Turing v0.6.17
[7200193e] Twiddle v1.1.0
[30578b45] URIParser v0.4.0
[1986cc42] Unitful v0.15.0
[81def892] VersionParsing v1.1.3
[ea10d353] WeakRefStrings v0.5.8
[0f1e0344] WebIO v0.8.4
[104b5d7c] WebSockets v1.5.2
[cc8bc4a8] Widgets v0.6.1
[c17dfb99] WinRPM v0.4.2
[efce3f68] WoodburyMatrices v0.4.1
[009559a3] XGBoost v0.3.1
[ddb6d928] YAML v0.3.2
[c2297ded] ZMQ v1.0.0
[a5390f91] ZipFile v0.8.3
[2a0f44e3] Base64
[ade2ca70] Dates
[8bb1440f] DelimitedFiles
[8ba89e20] Distributed
[7b1f6079] FileWatching
[9fa8497b] Future
[b77e0a4c] InteractiveUtils
[76f85450] LibGit2
[8f399da3] Libdl
[37e2e46d] LinearAlgebra
[56ddb016] Logging
[d6f4376e] Markdown
[a63ad114] Mmap
[44cfe95a] Pkg
[de0858da] Printf
[9abbd945] Profile
[3fa0cd96] REPL
[9a3f8284] Random
[ea8e919c] SHA
[9e88b42a] Serialization
[1a1011a3] SharedArrays
[6462fe0b] Sockets
[2f01184e] SparseArrays
[10745b16] Statistics
[4607b0f0] SuiteSparse
[8dfed614] Test
[cf7118a7] UUIDs
[4ec0a83e] Unicode
Here's also an extract from Manifest.toml
[[CategoricalArrays]]
deps = ["Compat", "Future", "JSON", "Missings", "Printf", "Reexport"]
git-tree-sha1 = "26601961df6afacdd16d67c1eec6cfe75e5ae9ab"
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.5.4"
[[DecisionTree]]
deps = ["DelimitedFiles", "Distributed", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics", "Test"]
path = "/home/simon/.julia/dev/DecisionTree"
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
version = "0.8.1+"
[[MLJ]]
deps = ["CSV", "CategoricalArrays", "Dates", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJModels", "Pkg", "ProgressMeter", "Random", "RecipesBase", "RemoteFiles", "Statistics", "StatsBase", "Tables"]
path = "/home/simon/.julia/dev/MLJ"
uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
version = "0.2.3"
[[MLJBase]]
deps = ["CSV", "CategoricalArrays", "Distributions", "InteractiveUtils", "Random", "SparseArrays", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "ef43664c9488e1de4a3b1bd15aa65b9d3dfc4d99"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.2.2"
[[MLJModels]]
deps = ["CategoricalArrays", "Distances", "Distributions", "LIBSVM", "LinearAlgebra", "MLJBase", "Pkg", "Random", "Requires"]
path = "/home/simon/.julia/dev/MLJModels"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.2.3"
I ran into the same issue.
Okay, the thing is, you can't presently call predict or a loss function on tasks. You have to give it actual data, as in the following example:
julia> using MLJ
julia> iris = load_iris()
julia> @load DecisionTreeClassifier
julia> tree_model = DecisionTreeClassifier(max_depth=2);
julia> tree = machine(tree_model, iris);
julia> fit!(tree, rows=train)
[ Info: Training Machine{DecisionTreeClassifier} @ 1…64.
Machine{DecisionTreeClassifier} @ 1…64
julia> X, y = iris();
julia> yhat = predict(tree, X[test,:]);
julia> misclassification_rate(yhat, y[test])
0.022222222222222223
You can, however, just evaluate the model without unpacking the task data. That is, you could reduce the above to:
julia> using MLJ
julia> iris = load_iris()
julia> @load DecisionTreeClassifier
julia> tree_model = DecisionTreeClassifier(max_depth=2);
julia> tree = machine(tree_model, iris)
julia> evaluate!(tree,
measure=misclassification_rate,
resampling=Holdout(fraction_train=0.7,
shuffle=true))
┌ Info: Evaluating using a holdout set.
│ fraction_train=0.7
│ shuffle=true
│ measure=MLJ.misclassification_rate
│ operation=StatsBase.predict
└ Resampling from all rows.
0.044444444444444446
Hope that helps. Thanks for raising an issue!
Re-opening.
@sdwfrost I only checked your newer gist file iris_v2. I'm not yet able to reproduce your error on iris.ipynb, but will look into it. Sorry, but could you just send me the output of status -p in the pkg manager, or send me your Project.toml?
Dear @ablaom, the output of status -p is above. Thanks for looking into this...
Okay, I have been able to reproduce this. This seems to be related to #159.
@sdwfrost Can you try pinning CategoricalArrays to v0.5.2 and see if your problem is resolved? (In the pkg manager: pin CategoricalArrays@0.5.2). This worked for me.
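For anyone who prefers a script to the pkg REPL prompt, the same pin can probably be applied through the Pkg API. This is a sketch, assuming it runs in the environment where MLJ is installed; the version number is the one reported to work in this thread:

```julia
using Pkg

# Pin CategoricalArrays at the version reported to work in this thread
Pkg.pin(PackageSpec(name="CategoricalArrays", version="0.5.2"))

# Show what the active environment now resolves to
Pkg.status()

# Once #159 is resolved, the pin can be lifted again with:
# Pkg.free("CategoricalArrays")
```

A pinned package is held at that version during later updates until it is freed.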
@ablaom pinning CategoricalArrays to v0.5.2 worked for me! I can certainly work with this for now.
CategoricalArrays compatibility requirement added for now, pending resolution of #159
In trying to run the Iris/DecisionTree example, I get the following error in fit:
I've tried updating DecisionTree, MLJ and MLJModels to the latest versions, and I believe I'm looking at the right docs... I've tried using both RDatasets and the task interface. Gist here