Closed mhauru closed 1 month ago
@mhauru, having an MWE for this as well would be quite helpful.
using Enzyme: Enzyme
using LinearAlgebra: Diagonal, transpose
"""
    invwsumsq(w::AbstractVector, a::AbstractVector)

Return the inverse-weighted sum of squares of `a`, i.e. the sum of
`abs2(a[i]) / w[i]` over all indices `i`.

# Throws
- `DimensionMismatch`: if `w` and `a` do not have the same axes.
"""
function invwsumsq(w::AbstractVector, a::AbstractVector)
    # Without this check a longer `a` would be silently truncated and a shorter
    # one would fail with an unrelated BoundsError in the middle of the loop.
    if axes(w) != axes(a)
        throw(DimensionMismatch(
            "w and a must have the same indices, got $(axes(w, 1)) and $(axes(a, 1))",
        ))
    end
    # Start from the zero of the quotient type of the two element types so the
    # accumulator keeps one type for the whole loop (e.g. Int inputs give Float64).
    acc = zero(zero(eltype(a)) / zero(eltype(w)))
    for i in eachindex(w)
        acc += abs2(a[i]) / w[i]
    end
    return acc
end
"""
    _logpdf(d, x)

Subtract `d.μ` from `x` elementwise and return `invwsumsq` of the result,
weighted by `d.Σ.diag`. The value is returned as is; no scaling or
normalization term is applied here.
"""
function _logpdf(d, x)
    centered = x .- d.μ
    return invwsumsq(d.Σ.diag, centered)
end
"""
    demo_func(x=transpose([1.5 2.0;]))

Evaluate `_logpdf` for a fixed two-dimensional setup: a hard-coded `μ` and a
unit `Diagonal` as `Σ`. `x` is reshaped to a length-2 vector before the call;
by default it is the 2×1 transpose of the 1×2 matrix `[1.5 2.0]`.
"""
function demo_func(x::Any=transpose([1.5 2.0;]);)
    center = [-0.30725218207431315, 0.5492115788562757]
    dist = (Σ = Diagonal([1.0, 1.0]), μ = center)
    # `reshape` turns the 2×1 input into the vector that `_logpdf` works on.
    return _logpdf(dist, reshape(x, (2,)))
end
"""
    f(x)

Return `demo_func()` evaluated at its default input. The argument `x` is
accepted but not used.
"""
function f(x)
    return demo_func()
end
# Input point passed to the differentiation call below.
x = [0.0, 0.0]
# Reverse-mode differentiation of `f` at `x`:
#   - `Const(f)`: the function object itself is not differentiated.
#   - `Active`: activity annotation for the return value of `f`.
#   - `Duplicated(x, zero(x))`: pairs `x` with a zero-initialized shadow of the
#     same shape, into which reverse mode accumulates the gradient.
# NOTE(review): `f` as defined in this file ignores its argument, so the shadow
# is expected to stay all zeros — confirm this is intended for the reproduction.
Enzyme.autodiff(
Enzyme.Reverse,
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(x, zero(x)),
)
There you go.
MWE:
Output