JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Add check_args and drop 1.3 #499

Open theogf opened 1 year ago

theogf commented 1 year ago

Summary: As in Distributions.jl, we do not necessarily want to check argument correctness all the time.

Proposed changes

This adds a `check_args` keyword to all constructors that require it, so that the argument-correctness checks can be deactivated. Additionally, this drops support for Julia 1.3 in favor of 1.6, because I am too lazy to write `kwarg=kwarg` everywhere (the shorthand appeared in 1.5 or 1.6, I can't remember).

What alternatives have you considered? We could define a reparametrization of the parameters, but we already discussed that this should be left to the user.

Breaking changes None! Crazy right?
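A minimal sketch of the proposed pattern (`MyLinearKernel` and its error message are illustrative stand-ins, not the exact PR diff):

```julia
# Sketch of the proposed `check_args` pattern; `MyLinearKernel` is a
# hypothetical stand-in for one of the package's kernel types.
struct MyLinearKernel{Tc<:Real}
    c::Vector{Tc}

    function MyLinearKernel(c::Real; check_args::Bool=true)
        # Validation runs by default but can be switched off, e.g. inside
        # an optimization loop where parameters are known to be valid.
        if check_args && !(c >= zero(c))
            throw(ArgumentError("MyLinearKernel: c = $c does not satisfy c ≥ 0"))
        end
        return new{typeof(c)}([c])
    end
end

MyLinearKernel(2.0)                    # passes the check
MyLinearKernel(-1.0; check_args=false) # skips the check
```

Since both calls go through the same constructor, existing call sites keep working unchanged, which is why the change is non-breaking.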

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 73.07%; project coverage change: -1.42% :warning:

Comparison: base (8746034) 94.25% vs. head (0c8185a) 92.83%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #499      +/-   ##
==========================================
- Coverage   94.25%   92.83%   -1.42%
==========================================
  Files          52       52
  Lines        1374     1396      +22
==========================================
+ Hits         1295     1296       +1
- Misses         79      100      +21
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/utils.jl | `73.07% <12.50%> (-18.39%)` | :arrow_down: |
| src/basekernels/constant.jl | `100.00% <100.00%> (ø)` | |
| src/basekernels/exponential.jl | `100.00% <100.00%> (ø)` | |
| src/basekernels/fbm.jl | `100.00% <100.00%> (ø)` | |
| src/basekernels/matern.jl | `100.00% <100.00%> (ø)` | |
| src/basekernels/periodic.jl | `100.00% <100.00%> (ø)` | |
| src/basekernels/polynomial.jl | `100.00% <100.00%> (ø)` | |
| src/basekernels/rational.jl | `100.00% <100.00%> (ø)` | |
| src/basekernels/wiener.jl | `92.85% <100.00%> (ø)` | |
| src/kernels/scaledkernel.jl | `88.23% <100.00%> (ø)` | |

... and 3 more.


devmotion commented 1 year ago

I'm a bit worried that this might cause performance regressions, or at least is suboptimal, based on the fixes that were needed in Distributions: https://github.com/JuliaStats/Distributions.jl/pull/1492

willtebbutt commented 1 year ago

Good point, @devmotion. @theogf, could you check some simple examples with Zygote?

theogf commented 1 year ago

> I'm a bit worried that this might cause performance regressions, or at least is suboptimal, based on the fixes that were needed in Distributions: JuliaStats/Distributions.jl#1492

Then shouldn't I just reuse `@check_args` from Distributions.jl?

theogf commented 1 year ago

@devmotion do you think that would be enough?

devmotion commented 1 year ago

Did you compare performance with Zygote?

theogf commented 1 year ago

Here are the benchmarks:

using Zygote
using BenchmarkTools
using KernelFunctions

x = [3.0]

macro old_check_args(K, param, cond, desc=string(cond))
    quote
        if !($(esc(cond)))
            throw(
                ArgumentError(
                    string(
                        $(string(K)),
                        ": ",
                        $(string(param)),
                        " = ",
                        $(esc(param)),
                        " does not ",
                        "satisfy the constraint ",
                        $(string(desc)),
                        ".",
                    ),
                ),
            )
        end
    end
end

struct OldLinearKernel{Tc<:Real} <: KernelFunctions.SimpleKernel
    c::Vector{Tc}

    function OldLinearKernel(c::Real)
        @old_check_args(OldLinearKernel, c, c >= zero(c), "c ≥ 0")
        return new{typeof(c)}([c])
    end
end

function f(x)
    k = LinearKernel(;c = x[1])
    sum(k.c)
end
function g(x)
    k = LinearKernel(;c = x[1], check_args=false)
    sum(k.c)
end
function h(x)
    k = OldLinearKernel(x[1])
    sum(k.c)
end

@btime Zygote.gradient($f, $x) # 15.980 μs (150 allocations: 5.89 KiB)
@btime Zygote.gradient($g, $x) #  13.853 μs (142 allocations: 5.72 KiB)
@btime Zygote.gradient($h, $x)  #  4.700 μs (51 allocations: 1.89 KiB)

devmotion commented 1 year ago

That's a quite noticeable regression. Do we know what exactly causes it?

theogf commented 1 year ago

For completeness, I added the constructor

    function OldLinearKernel(c::Real; check_args=true)
        check_args && @old_check_args(OldLinearKernel, c, c >= zero(c), "c ≥ 0")
        return new{typeof(c)}([c])
    end

And these are the results:

function h(x)
    k = OldLinearKernel(x[1])
    sum(k.c)
end
function i(x)
    k = OldLinearKernel(x[1]; check_args=false)
    sum(k.c)
end

@btime Zygote.gradient($h, $x)  # 6.086 μs (65 allocations: 2.58 KiB)
@btime Zygote.gradient($i, $x) # 10.404 μs (91 allocations: 3.38 KiB)
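One possible mitigation that could be benchmarked the same way (my assumption, not something proposed in the PR): lift the flag into the type domain with `Val`, so the branch is a compile-time constant and the keyword-argument machinery never reaches Zygote. A sketch, with `ValLinearKernel` as a hypothetical name:

```julia
# Hypothetical alternative, not from the PR: encode the check flag as a
# `Val` type parameter so the branch is resolved at compile time.
struct ValLinearKernel{Tc<:Real}
    c::Vector{Tc}
end

function ValLinearKernel(c::Real, ::Val{check}=Val(true)) where {check}
    if check && !(c >= zero(c))
        throw(ArgumentError("ValLinearKernel: c = $c does not satisfy c ≥ 0"))
    end
    return ValLinearKernel{typeof(c)}([c])
end

ValLinearKernel(3.0)              # checked
ValLinearKernel(-1.0, Val(false)) # unchecked
```

Whether this actually removes the Zygote overhead would need the same `@btime` comparison as above.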

theogf commented 6 months ago

Looking back at this, how about using CheckArgs.jl?