theogf opened this pull request 1 year ago.
Patch coverage: 73.07% and project coverage change: -1.42% :warning:
Comparison is base (8746034) 94.25% compared to head (0c8185a) 92.83%.
I'm a bit worried that this might cause performance regressions, or at least be suboptimal, based on the fixes that were needed in Distributions: https://github.com/JuliaStats/Distributions.jl/pull/1492
Good point, @devmotion. @theogf, could you check some simple examples with Zygote?
Then shouldn't I just reuse the `@check_args` macro from Distributions.jl? @devmotion, do you think that would be enough?
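For illustration, here is a minimal, hypothetical sketch of a checking macro combined with a `check_args` keyword. Building and throwing the error is moved into a `@noinline` helper, a common Julia pattern for keeping the error path out of the constructor body; whether this helps under Zygote is an untested assumption. All names (`throw_check_error`, `@sketch_check_args`, `SketchLinearKernel`) are illustrative and are not the Distributions.jl or KernelFunctions.jl API.

```julia
# Hypothetical sketch: names below are illustrative, not KernelFunctions.jl
# or Distributions.jl API.

# Build and throw the error out of line, so the constructor body itself only
# contains the comparison and a function call.
@noinline function throw_check_error(K::String, param::String, value, desc::String)
    throw(ArgumentError(string(K, ": ", param, " = ", value,
                               " does not satisfy the constraint ", desc, ".")))
end

# Same interface as `@old_check_args`, but delegates message construction.
macro sketch_check_args(K, param, cond, desc=string(cond))
    return quote
        if !($(esc(cond)))
            throw_check_error($(string(K)), $(string(param)), $(esc(param)), $(string(desc)))
        end
    end
end

struct SketchLinearKernel{T<:Real}
    c::Vector{T}
    function SketchLinearKernel(c::Real; check_args::Bool=true)
        # `check_args=false` skips the validation entirely
        check_args && @sketch_check_args(SketchLinearKernel, c, c >= zero(c), "c ≥ 0")
        return new{typeof(c)}([c])
    end
end
```

With this, `SketchLinearKernel(-1.0)` throws an `ArgumentError`, while `SketchLinearKernel(-1.0; check_args=false)` constructs the kernel without validating `c`.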
Did you compare performance with Zygote?
Here are the benchmarks:
```julia
using Zygote
using BenchmarkTools
using KernelFunctions

x = [3.0]

macro old_check_args(K, param, cond, desc=string(cond))
    quote
        if !($(esc(cond)))
            throw(
                ArgumentError(
                    string(
                        $(string(K)),
                        ": ",
                        $(string(param)),
                        " = ",
                        $(esc(param)),
                        " does not ",
                        "satisfy the constraint ",
                        $(string(desc)),
                        ".",
                    ),
                ),
            )
        end
    end
end

struct OldLinearKernel{Tc<:Real} <: KernelFunctions.SimpleKernel
    c::Vector{Tc}
    function OldLinearKernel(c::Real)
        @old_check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
        return new{typeof(c)}([c])
    end
end

function f(x)
    k = LinearKernel(; c=x[1])
    sum(k.c)
end

function g(x)
    k = LinearKernel(; c=x[1], check_args=false)
    sum(k.c)
end

function h(x)
    k = OldLinearKernel(x[1])
    sum(k.c)
end

@btime Zygote.gradient($f, $x) # 15.980 μs (150 allocations: 5.89 KiB)
@btime Zygote.gradient($g, $x) # 13.853 μs (142 allocations: 5.72 KiB)
@btime Zygote.gradient($h, $x) # 4.700 μs (51 allocations: 1.89 KiB)
```
That's quite a noticeable regression. Do we know what exactly causes it?
For completeness, I added a `check_args` keyword to the inner constructor of `OldLinearKernel`:
```julia
function OldLinearKernel(c::Real; check_args=true)
    check_args && @old_check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
    return new{typeof(c)}([c])
end
```
And these are the results:
```julia
function h(x)
    k = OldLinearKernel(x[1])
    sum(k.c)
end

function i(x)
    k = OldLinearKernel(x[1]; check_args=false)
    sum(k.c)
end

@btime Zygote.gradient($h, $x) # 6.086 μs (65 allocations: 2.58 KiB)
@btime Zygote.gradient($i, $x) # 10.404 μs (91 allocations: 3.38 KiB)
```
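The numbers above suggest the keyword call itself may contribute to the cost: passing `check_args=false` explicitly (`i`, 10.404 μs) is slower than relying on the default (`h`, 6.086 μs), and both are slower than the keyword-free original (4.700 μs). One option worth benchmarking, as an untested assumption rather than anything measured in this thread, is to make the flag a positional argument of the inner constructor and keep the keyword method as a thin forwarding wrapper. `PosLinearKernel` is a hypothetical stand-in for `OldLinearKernel`.

```julia
# Hypothetical sketch: the validation flag is a positional argument of the
# inner constructor, so calling it involves no keyword-argument handling.
struct PosLinearKernel{T<:Real}
    c::Vector{T}
    function PosLinearKernel(c::Real, check_args::Bool)
        if check_args && !(c >= zero(c))
            throw(ArgumentError("PosLinearKernel: c = $c does not satisfy the constraint c ≥ 0."))
        end
        return new{typeof(c)}([c])
    end
end

# Keyword convenience method; it only forwards to the positional constructor.
PosLinearKernel(c::Real; check_args::Bool=true) = PosLinearKernel(c, check_args)
```

A benchmark function could then call `PosLinearKernel(x[1], false)` directly and be compared against `PosLinearKernel(x[1]; check_args=false)`.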
Looking back at this, how about using CheckArgs.jl?
Summary
Similarly to Distributions.jl, we do not necessarily want to check argument correctness all the time.
Proposed changes
This adds a `check_args` keyword to all constructors that validate their arguments, so that the check can be deactivated. Additionally, this raises the minimum supported Julia version from 1.3 to 1.6, because I am too lazy to write `kwarg=kwarg` (the shorthand appeared in 1.5 or 1.6, I can't remember).
What alternatives have you considered?
We could define a reparametrization of the parameters, but we already discussed that it should be the user's duty to do this.
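The user-side reparametrization mentioned as the alternative can be sketched as follows: optimize an unconstrained parameter θ and map it to c = exp(θ), which is never negative, so the constructor's check can be switched off safely. This is a hypothetical example; `ExpLinearKernel` and `kernel_from_theta` are illustrative names, and `exp` is just one possible non-negative transform.

```julia
# Hypothetical stand-in for a kernel whose parameter must satisfy c ≥ 0.
struct ExpLinearKernel{T<:Real}
    c::T
    function ExpLinearKernel(c::Real; check_args::Bool=true)
        if check_args && !(c >= zero(c))
            throw(ArgumentError("ExpLinearKernel: c = $c does not satisfy the constraint c ≥ 0."))
        end
        return new{typeof(c)}(c)
    end
end

# User-side reparametrization: any real theta maps to c = exp(theta) ≥ 0
# (exp can underflow to 0.0, which still satisfies the constraint), so the
# check can be skipped.
kernel_from_theta(theta::Real) = ExpLinearKernel(exp(theta); check_args=false)
```

Because the constraint holds by construction, disabling the check here cannot let an invalid `c` through.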
Breaking changes
None! Crazy, right?