JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License
267 stars 32 forks source link

Failures on master #526

Open simsurace opened 11 months ago

simsurace commented 11 months ago

The BaseKernels failures on master can be traced back to moving ChainRules from 1.52.0 to 1.53.0. I think some weird interaction between ChainRules and Zygote pullbacks is at play. The upside is that they are resolved by using https://github.com/JuliaStats/Distances.jl/pull/246.

simsurace commented 11 months ago

Actually, I re-tested now. The failures of the MaternKernel on master can only be fixed by doing one of the following:

simsurace commented 10 months ago

The remaining failures are due to the non-deterministic nature of allocations, which showed up in Julia 1.9 vs 1.8 (see #530). It remains to check where the occasional additional allocation comes from, or else to relax the tests to allow deviations of +/- 1 allocation.

simsurace commented 6 months ago

Ok, we are getting there! Any ideas on what the last remaining failure is about?

Zygote: Error During Test at /home/runner/work/KernelFunctions.jl/KernelFunctions.jl/test/transform/selecttransform.jl:75
  Got exception outside of a @test
  ArgumentError: unable to check bounds for indices of type Symbol
  Stacktrace:
    [1] checkindex(::Type{Bool}, inds::Base.OneTo{Int64}, i::Symbol)
      @ Base ./abstractarray.jl:756
    [2] checkindex
      @ Base ./abstractarray.jl:773 [inlined]
    [3] checkbounds_indices
      @ Base ./abstractarray.jl:725 [inlined]
    [4] checkbounds
      @ Base ./abstractarray.jl:678 [inlined]
    [5] checkbounds
      @ Base ./abstractarray.jl:699 [inlined]
    [6] view(::Matrix{Float64}, ::Vector{Symbol}, ::Base.Slice{Base.OneTo{Int64}})
      @ Base ./subarray.jl:179
    [7] ∇getindex!(::Matrix{Float64}, ::Matrix{Float64}, ::Vector{Symbol}, ::Vararg{Any})
      @ ChainRules ~/.julia/packages/ChainRules/Gw0tZ/src/rulesets/Base/indexing.jl:151
    [8] ∇getindex(::AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, ::Matrix{Float64}, ::Vector{Symbol}, ::Vararg{Any})
      @ ChainRules ~/.julia/packages/ChainRules/Gw0tZ/src/rulesets/Base/indexing.jl:89
    [9] #1579
      @ ~/.julia/packages/ChainRules/Gw0tZ/src/rulesets/Base/indexing.jl:69 [inlined]
   [10] unthunk
      @ ~/.julia/packages/ChainRulesCore/PvTbU/src/tangent_types/thunks.jl:204 [inlined]
   [11] unthunk
      @ ~/.julia/packages/ChainRulesCore/PvTbU/src/tangent_types/thunks.jl:237 [inlined]
   [12] wrap_chainrules_output
      @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:110 [inlined]
   [13] map
      @ ./tuple.jl:293 [inlined]
   [14] map
      @ ./tuple.jl:294 [inlined]
   [15] wrap_chainrules_output
      @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:111 [inlined]
   [16] ZBack
      @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
   [17] _map
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/src/transform/selecttransform.jl:28 [inlined]
   [18] kernelmatrix
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/src/kernels/transformedkernel.jl:113 [inlined]
   [19] #kernelmatrix#78
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/src/matrix/kernelmatrix.jl:189 [inlined]
   [20] (::Zygote.Pullback{Tuple{KernelFunctions.var"##kernelmatrix#78", Int64, typeof(kernelmatrix), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Tuple{Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{obsdim::Int64}, typeof(KernelFunctions.vec_of_vecs), AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Any}, Zygote.Pullback{Tuple{typeof(kernelmatrix), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, ColVecs{Float64, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, AxisVector{Float64, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{Axis{:row, Vector{Symbol}}}}}}, Tuple{Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:transform, Zygote.Context{false}, TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, SelectTransform{Vector{Symbol}}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions._map), SelectTransform{Vector{Symbol}}, ColVecs{Float64, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, AxisVector{Float64, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{Axis{:row, Vector{Symbol}}}}}}, Tuple{Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:select, Zygote.Context{false}, SelectTransform{Vector{Symbol}}, Vector{Symbol}}}, Zygote.ZBack{ChainRules.var"#view_pullback#1589"{AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, Tuple{Vector{Symbol}, Colon}, Tuple{NoTangent, NoTangent}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions._wrap), AxisMatrix{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, 
Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, Type{ColVecs}}, Tuple{Zygote.ZBack{KernelFunctions.var"#ColVecs_pullback#159"}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:X, Zygote.Context{false}, ColVecs{Float64, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, AxisVector{Float64, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{Axis{:row, Vector{Symbol}}}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}}}}, Zygote.Pullback{Tuple{typeof(kernelmatrix), SqExponentialKernel{Euclidean}, ColVecs{Float64, AxisMatrix{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, AxisVector{Float64, SubArray{Float64, 1, Matrix{Float64}, Tuple{Vector{Int64}, Int64}, false}, Tuple{Axis{:row, Vector{Symbol}}}}}}, Tuple{Zygote.var"#2210#back#313"{Zygote.Jnew{KernelFunctions.var"#68#69"{SqExponentialKernel{Euclidean}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(metric), SqExponentialKernel{Euclidean}}, Tuple{Zygote.Pullback{Tuple{Type{SqEuclidean}}, Tuple{}}}}, Zygote.var"#2845#back#673"{Zygote.var"#map_back#667"{KernelFunctions.var"#68#69"{SqExponentialKernel{Euclidean}}, 1, Tuple{Matrix{Float64}}, Tuple{Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Matrix{Tuple{Float64, Zygote.Pullback{Tuple{KernelFunctions.var"#68#69"{SqExponentialKernel{Euclidean}}, Float64}, Tuple{Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:κ, Zygote.Context{false}, KernelFunctions.var"#68#69"{SqExponentialKernel{Euclidean}}, SqExponentialKernel{Euclidean}}}, Zygote.Pullback{Tuple{typeof(kappa), SqExponentialKernel{Euclidean}, Float64}, Tuple{Zygote.ZBack{ChainRules.var"#/_pullback#1317"{Float64, Float64, ProjectTo{Float64, @NamedTuple{}}, ProjectTo{Float64, @NamedTuple{}}}}, 
Zygote.ZBack{ChainRules.var"#exp_pullback#1301"{Float64, ProjectTo{Float64, @NamedTuple{}}}}, Zygote.ZBack{ChainRules.var"#-_pullback#1325"{Int64, ProjectTo{Float64, @NamedTuple{}}}}}}}}}}}}, Zygote.Pullback{Tuple{typeof(KernelFunctions.pairwise), SqEuclidean, ColVecs{Float64, AxisMatrix{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, AxisVector{Float64, SubArray{Float64, 1, Matrix{Float64}, Tuple{Vector{Int64}, Int64}, false}, Tuple{Axis{:row, Vector{Symbol}}}}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.var"#2220#back#315"{Zygote.Jnew{@NamedTuple{dims::Int64}, Nothing, true}}}}, ZygoteDistancesExt.var"#63#back#30"{ZygoteDistancesExt.var"#32#33"{AxisMatrix{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, typeof(identity)}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:X, Zygote.Context{false}, ColVecs{Float64, AxisMatrix{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, AxisVector{Float64, SubArray{Float64, 1, Matrix{Float64}, Tuple{Vector{Int64}, Int64}, false}, Tuple{Axis{:row, Vector{Symbol}}}}}, AxisMatrix{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}}, Zygote.var"#2013#back#204"{typeof(identity)}}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:kernel, Zygote.Context{false}, TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, SqExponentialKernel{Euclidean}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:obsdim,)}}, Tuple{Int64}}, 
Tuple{Zygote.var"#2220#back#315"{Zygote.Jnew{@NamedTuple{obsdim::Int64}, Nothing, true}}}}, Zygote.var"#2013#back#204"{typeof(identity)}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
      @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
   [21] kernelmatrix
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/src/matrix/kernelmatrix.jl:188 [inlined]
   [22] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{obsdim::Int64}, typeof(kernelmatrix), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Any})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
      @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
   [23] testfunction
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/test_utils.jl:83 [inlined]
   [24] (::Zygote.Pullback{Tuple{typeof(testfunction), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, Int64}, Tuple{Zygote.var"#2013#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:obsdim,)}}, Tuple{Int64}}, Tuple{Zygote.var"#2220#back#315"{Zygote.Jnew{@NamedTuple{obsdim::Int64}, Nothing, true}}}}, Zygote.var"#2989#back#768"{Zygote.var"#762#766"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{obsdim::Int64}, typeof(kernelmatrix), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Any}}})(Δ::Float64)
      @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
   [25] #193
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/transform/selecttransform.jl:80 [inlined]
   [26] (::Zygote.Pullback{Tuple{var"#193#218"{TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Tuple{Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ta_row, Zygote.Context{false}, var"#193#218"{TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}}, TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}}}, Zygote.Pullback{Tuple{typeof(testfunction), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, Int64}, Tuple{Zygote.var"#2013#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:obsdim,)}}, Tuple{Int64}}, Tuple{Zygote.var"#2220#back#315"{Zygote.Jnew{@NamedTuple{obsdim::Int64}, Nothing, true}}}}, Zygote.var"#2989#back#768"{Zygote.var"#762#766"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{obsdim::Int64}, typeof(kernelmatrix), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Any}}}}})(Δ::Float64)
      @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
   [27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#193#218"{TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Tuple{Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ta_row, Zygote.Context{false}, var"#193#218"{TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}}, TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}}}, Zygote.Pullback{Tuple{typeof(testfunction), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}, Int64}, Tuple{Zygote.var"#2013#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:obsdim,)}}, Tuple{Int64}}, Tuple{Zygote.var"#2220#back#315"{Zygote.Jnew{@NamedTuple{obsdim::Int64}, Nothing, true}}}}, Zygote.var"#2989#back#768"{Zygote.var"#762#766"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{obsdim::Int64}, typeof(kernelmatrix), TransformedKernel{SqExponentialKernel{Euclidean}, SelectTransform{Vector{Symbol}}}, AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}}}, Any}}}}}})(Δ::Float64)
      @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
   [28] gradient(f::Function, args::AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}})
      @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
   [29] gradient(f::Function, ::Val{:Zygote}, args::AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}})
      @ Main ~/work/KernelFunctions.jl/KernelFunctions.jl/test/test_utils.jl:48
   [30] gradient(f::Function, s::Symbol, args::AxisMatrix{Float64, Matrix{Float64}, Tuple{Axis{:row, Vector{Symbol}}, Axis{:col, Vector{Symbol}}}})
      @ Main ~/work/KernelFunctions.jl/KernelFunctions.jl/test/test_utils.jl:45
   [31] macro expansion
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/transform/selecttransform.jl:79 [inlined]
   [32] macro expansion
      @ /opt/hostedtoolcache/julia/1.10.0/x64/share/julia/stdlib/v1.10/Test/src/Test.jl:1669 [inlined]
   [33] macro expansion
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/transform/selecttransform.jl:75 [inlined]
   [34] macro expansion
      @ /opt/hostedtoolcache/julia/1.10.0/x64/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [35] top-level scope
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/transform/selecttransform.jl:2
   [36] include(fname::String)
      @ Base.MainInclude ./client.jl:489
   [37] macro expansion
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/runtests.jl:79 [inlined]
   [38] macro expansion
      @ /opt/hostedtoolcache/julia/1.10.0/x64/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [39] macro expansion
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/runtests.jl:69 [inlined]
   [40] macro expansion
      @ /opt/hostedtoolcache/julia/1.10.0/x64/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [41] top-level scope
      @ ~/work/KernelFunctions.jl/KernelFunctions.jl/test/runtests.jl:67
   [42] include(fname::String)
      @ Base.MainInclude ./client.jl:489
   [43] top-level scope
      @ none:6
   [44] eval
      @ Core ./boot.jl:385 [inlined]
   [45] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:291
   [46] _start()
      @ Base ./client.jl:552
devmotion commented 6 months ago

This is probably caused by some ChainRules rule for getindex that can't cope with Symbols as indices but is hit in our AxisArrays tests.