JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License
267 stars 32 forks source link

`PeriodicKernel` does not play well with AD #527

Closed simsurace closed 6 months ago

simsurace commented 11 months ago

I don't know how long this has been broken:

julia> using TestEnv; TestEnv.activate()
"/tmp/jl_cThCwu/Project.toml"

julia> using KernelFunctions, Zygote

julia> build_kernel(θ) = PeriodicKernel(; r = [θ.r])
build_kernel (generic function with 1 method)

julia> θ = (r = 1., )
(r = 1.0,)

julia> testfun(θ) = sum(kernelmatrix(build_kernel(θ), rand(10)))
testfun (generic function with 1 method)

julia> Zygote.gradient(testfun, θ)
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/array.jl:88
  [3] (::Zygote.var"#555#556"{Matrix{Float64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/array.jl:100
  [4] (::Zygote.var"#2659#back#557"{Zygote.var"#555#556"{Matrix{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [5] Pullback
    @ ~/.julia/packages/Distances/yhVAl/src/generic.jl:219 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(Distances._pairwise!), KernelFunctions.Sinus{Float64}, Matrix{Float64}, Matrix{Float64}}, Any})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/packages/Distances/yhVAl/src/generic.jl:287 [inlined]
  [8] (::Zygote.Pullback{Tuple{Distances.var"##pairwise!#2", Int64, typeof(StatsAPI.pairwise!), KernelFunctions.Sinus{Float64}, Matrix{Float64}, Matrix{Float64}}, Any})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/Distances/yhVAl/src/generic.jl:275 [inlined]
 [10] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:dims,), Tuple{Int64}}, typeof(StatsAPI.pairwise!), KernelFunctions.Sinus{Float64}, Matrix{Float64}, Matrix{Float64}}, Any})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Distances/yhVAl/src/generic.jl:330 [inlined]
 [12] (::Zygote.Pullback{Tuple{Distances.var"##pairwise#4", Int64, typeof(StatsAPI.pairwise), KernelFunctions.Sinus{Float64}, Matrix{Float64}}, Any})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/Distances/yhVAl/src/generic.jl:324 [inlined]
 [14] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:dims,), Tuple{Int64}}, typeof(StatsAPI.pairwise), KernelFunctions.Sinus{Float64}, Matrix{Float64}}, Any})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/dev/KernelFunctions.jl/src/distances/pairwise.jl:16 [inlined]
 [16] (::Zygote.Pullback{Tuple{typeof(KernelFunctions.pairwise), KernelFunctions.Sinus{Float64}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:dims,), Tuple{Int64}}, typeof(StatsAPI.pairwise), KernelFunctions.Sinus{Float64}, Matrix{Float64}}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/dev/KernelFunctions.jl/src/matrix/kernelmatrix.jl:150 [inlined]
 [18] (::Zygote.Pullback{Tuple{typeof(kernelmatrix), PeriodicKernel{Float64}, Vector{Float64}}, Tuple{Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, 1, Tuple{Matrix{Float64}}, Tuple{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Matrix{Tuple{Float64, Zygote.Pullback{Tuple{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, Float64}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:κ, Zygote.Context{false}, KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, PeriodicKernel{Float64}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.kappa), PeriodicKernel{Float64}, Float64}, Tuple{Zygote.ZBack{ChainRules.var"#exp_pullback#1324"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}, Zygote.ZBack{ChainRules.var"#times_pullback2#1350"{Float64, Float64}}}}}}}}}}, Zygote.var"#2214#back#313"{Zygote.Jnew{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.metric), PeriodicKernel{Float64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:r, Zygote.Context{false}, PeriodicKernel{Float64}, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{KernelFunctions.Sinus}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{KernelFunctions.Sinus{Float64}}, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2214#back#313"{Zygote.Jnew{KernelFunctions.Sinus{Float64}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Float64}}, Vector{Float64}}, Any}}}}}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.pairwise), KernelFunctions.Sinus{Float64}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:dims,), 
Tuple{Int64}}, typeof(StatsAPI.pairwise), KernelFunctions.Sinus{Float64}, Matrix{Float64}}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[5]:1 [inlined]
 [20] (::Zygote.Pullback{Tuple{typeof(testfun), NamedTuple{(:r,), Tuple{Float64}}}, Tuple{Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(build_kernel), NamedTuple{(:r,), Tuple{Float64}}}, Tuple{Zygote.ZBack{ChainRules.var"#vect_pullback#1373"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:r,), Tuple{Vector{Float64}}}, Type{PeriodicKernel}}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:r,)}}, Tuple{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:r,), Tuple{Vector{Float64}}}}, Tuple{Vector{Float64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:r,), Tuple{Vector{Float64}}}, Nothing, true}}}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:r, Zygote.Context{false}, NamedTuple{(:r,), Tuple{Float64}}, Float64}}}}, Zygote.Pullback{Tuple{typeof(kernelmatrix), PeriodicKernel{Float64}, Vector{Float64}}, Tuple{Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, 1, Tuple{Matrix{Float64}}, Tuple{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Matrix{Tuple{Float64, Zygote.Pullback{Tuple{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, Float64}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:κ, Zygote.Context{false}, KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, PeriodicKernel{Float64}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.kappa), PeriodicKernel{Float64}, Float64}, Tuple{Zygote.ZBack{ChainRules.var"#exp_pullback#1324"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}, Zygote.ZBack{ChainRules.var"#times_pullback2#1350"{Float64, Float64}}}}}}}}}}, Zygote.var"#2214#back#313"{Zygote.Jnew{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.metric), PeriodicKernel{Float64}}, 
Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:r, Zygote.Context{false}, PeriodicKernel{Float64}, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{KernelFunctions.Sinus}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{KernelFunctions.Sinus{Float64}}, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2214#back#313"{Zygote.Jnew{KernelFunctions.Sinus{Float64}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Float64}}, Vector{Float64}}, Any}}}}}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.pairwise), KernelFunctions.Sinus{Float64}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:dims,), Tuple{Int64}}, typeof(StatsAPI.pairwise), KernelFunctions.Sinus{Float64}, Matrix{Float64}}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#rand_pullback#2182"{Tuple{Int64}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(testfun), NamedTuple{(:r,), Tuple{Float64}}}, Tuple{Zygote.var"#3027#back#782"{Zygote.var"#776#780"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(build_kernel), NamedTuple{(:r,), Tuple{Float64}}}, Tuple{Zygote.ZBack{ChainRules.var"#vect_pullback#1373"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:r,), Tuple{Vector{Float64}}}, Type{PeriodicKernel}}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:r,)}}, Tuple{Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:r,), Tuple{Vector{Float64}}}}, Tuple{Vector{Float64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:r,), Tuple{Vector{Float64}}}, Nothing, true}}}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:r, Zygote.Context{false}, NamedTuple{(:r,), Tuple{Float64}}, Float64}}}}, Zygote.Pullback{Tuple{typeof(kernelmatrix), PeriodicKernel{Float64}, Vector{Float64}}, Tuple{Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, 1, Tuple{Matrix{Float64}}, Tuple{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Matrix{Tuple{Float64, Zygote.Pullback{Tuple{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, Float64}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:κ, Zygote.Context{false}, KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, PeriodicKernel{Float64}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.kappa), PeriodicKernel{Float64}, Float64}, Tuple{Zygote.ZBack{ChainRules.var"#exp_pullback#1324"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}, Zygote.ZBack{ChainRules.var"#times_pullback2#1350"{Float64, Float64}}}}}}}}}}, Zygote.var"#2214#back#313"{Zygote.Jnew{KernelFunctions.var"#68#69"{PeriodicKernel{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.metric), PeriodicKernel{Float64}}, 
Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:r, Zygote.Context{false}, PeriodicKernel{Float64}, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{KernelFunctions.Sinus}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{KernelFunctions.Sinus{Float64}}, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2214#back#313"{Zygote.Jnew{KernelFunctions.Sinus{Float64}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{Vector{Float64}}, Vector{Float64}}, Any}}}}}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.pairwise), KernelFunctions.Sinus{Float64}, Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:dims,), Tuple{Int64}}, typeof(StatsAPI.pairwise), KernelFunctions.Sinus{Float64}, Matrix{Float64}}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#rand_pullback#2182"{Tuple{Int64}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
 [22] gradient(f::Function, args::NamedTuple{(:r,), Tuple{Float64}})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:97
 [23] top-level scope
    @ REPL[6]:1
simsurace commented 6 months ago

@JuliaRegistrator register()

JuliaRegistrator commented 6 months ago

Error while trying to register: Version 0.10.61 already exists