STOR-i / GaussianProcesses.jl

A Julia package for Gaussian Processes
https://stor-i.github.io/GaussianProcesses.jl/latest/

GPE ScikitLearn fit! gives error #149

Closed: OkonSamuel closed this issue 4 years ago

OkonSamuel commented 4 years ago

Hello. I recently tried to fit a GPE using the ScikitLearn.jl interface, and it raised an error. The code is shown below:

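using GaussianProcesses, ScikitLearn
# Zero-mean GP with an isotropic squared-exponential kernel (SEIso) and log noise -1
gp = GPE(mean=MeanZero(), kernel=SE(0.0, 0.0), logNoise=-1.0)
X = rand(150, 4)   # 150 observations in rows, 4 features in columns (ScikitLearn layout)
y = rand(150)
ScikitLearn.fit!(gp, X, y)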

This gives the following error:

ERROR: MethodError: no method matching fit!(::GPE{Array{Float64,2},Array{Float64,1},MeanZero,SEIso{Float64},GaussianProcesses.FullCovariance,GaussianProcesses.IsotropicData{Array{Float64,2}},PDMats.PDMat{Float64,Array{Float64,2}},GaussianProcesses.Scalar{Float64}}, ::Adjoint{Float64,Array{Float64,2}}, ::Array{Float64,1})
Closest candidates are:
  fit!(::GPE, ::AbstractArray{T,1} where T, ::AbstractArray{T,1} where T) at /home/okonsamuel/.julia/packages/GaussianProcesses/sed6i/src/GPE.jl:138
  fit!(::GPE{X,Y,M,K,CS,D,P,NOI} where NOI<:GaussianProcesses.Param where P<:PDMats.AbstractPDMat where D<:KernelData where CS<:GaussianProcesses.CovarianceStrategy where K<:Kernel where M<:GaussianProcesses.Mean, ::X, ::Y) where {X, Y} at /home/okonsamuel/.julia/packages/GaussianProcesses/sed6i/src/GPE.jl:129
Stacktrace:
 [1] fit!(::GPE{Array{Float64,2},Array{Float64,1},MeanZero,SEIso{Float64},GaussianProcesses.FullCovariance,GaussianProcesses.IsotropicData{Array{Float64,2}},PDMats.PDMat{Float64,Array{Float64,2}},GaussianProcesses.Scalar{Float64}}, ::Array{Float64,2}, ::Array{Float64,1}) at /home/okonsamuel/.julia/packages/GaussianProcesses/sed6i/src/ScikitLearn.jl:7
 [2] top-level scope at REPL[147]:1

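On further investigation I found that the fix is to change

ScikitLearnBase.fit!(gp::GPE, X::AbstractMatrix, y::AbstractVector) = fit!(gp, X', y)

to

ScikitLearnBase.fit!(gp::GPE, X::AbstractMatrix, y::AbstractVector) = fit!(gp, permutedims(X), y)

The transpose X' is a lazy Adjoint{Float64,Array{Float64,2}} wrapper, but the typed method fit!(::GPE{X,Y,...}, ::X, ::Y) requires the new data to have exactly the types fixed when the GPE was initialized. With the defaults used here those are Array{Float64,2} for X and Array{Float64,1} for y, so the Adjoint does not match. permutedims(X) returns a plain Array{Float64,2}, which does.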
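The dispatch failure can be reproduced without GaussianProcesses.jl. The sketch below uses a hypothetical ToyGP type and refit function (neither is part of the package) that mimic the pattern in the error message: the data container type becomes a type parameter of the model, and the refit method only accepts inputs of exactly that type.

```julia
using LinearAlgebra  # provides the Adjoint type returned by X'

# Hypothetical stand-in for GPE: the input container type becomes a type
# parameter when the model is constructed, as in GPE{Array{Float64,2},...}.
struct ToyGP{M<:AbstractMatrix}
    x::M
end

# Hypothetical refit mirroring fit!(::GPE{X,...}, ::X, ::Y) where {X,Y}:
# new inputs must have exactly the model's type, not merely any AbstractMatrix.
refit(gp::ToyGP{M}, x::M) where {M<:AbstractMatrix} = ToyGP(x)

gp = ToyGP(zeros(4, 0))   # ToyGP{Matrix{Float64}}, like a GPE built without data
X = rand(150, 4)          # 150 observations in rows (ScikitLearn convention)

# permutedims copies into a plain Matrix{Float64}, so dispatch succeeds.
fitted = refit(gp, permutedims(X))

# X' is a lazy Adjoint{Float64,Matrix{Float64}} wrapper, so no method matches.
err = try
    refit(gp, X')
    nothing
catch e
    e
end
```

Because Julia's parametric types are invariant, ToyGP{Matrix{Float64}} pins M to Matrix{Float64}, and an Adjoint is a different concrete type even though both are AbstractMatrix subtypes. Matrix(X') or collect(X') would also produce a plain Matrix and satisfy the same signature.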