Hi, here we go again. I tried to resolve the previous type instability issue by forcing the return type. This, however, results in a build error. I have a nice MWE for this one.
using Enzyme, Optimisers, Functors, Distributions, LinearAlgebra, SimpleUnPack
"""
    MvLocationScale(location, scale, dist, scale_eps)

Continuous multivariate distribution type holding a `location`, a `scale`, a base
distribution `dist` (any `ContinuousDistribution`), and a `scale_eps` value.

Every field has its own type parameter, so instances are concretely typed.
NOTE(review): `scale_eps` is only stored and forwarded in this file; presumably it is a
small lower bound/jitter for the scale — confirm against the code that consumes it.
"""
struct MvLocationScale{
    S,
    D<:ContinuousDistribution,
    L,
    E,
} <: ContinuousMultivariateDistribution
    location::L
    scale::S
    dist::D
    scale_eps::E
end
# Length of the distribution, defined as the number of entries in its `location` field.
function Base.length(q::MvLocationScale)
    return length(q.location)
end
# Register `MvLocationScale` with Functors, naming only `location` and `scale` as its
# functor fields; `dist` and `scale_eps` are deliberately left out of the list.
@functor MvLocationScale (location, scale)
"""
    RestructureMeanField(model)

Wrapper around a `Diagonal`-scale `MvLocationScale` that is used as a callable to rebuild
a distribution of the same kind from a flat parameter vector, reusing `model.dist` and
`model.scale_eps`.

All four type parameters of `MvLocationScale` (`S`, `D`, `L`, `E`) are carried here so
that the `model` field is concretely typed. Leaving the last parameter off turns the
field type into a `UnionAll`, which makes `typeof(re.model)` and every value read through
`re.model` impossible to infer from the type of `re` alone.
"""
struct RestructureMeanField{S <: Diagonal, D, L, E}
    model::MvLocationScale{S, D, L, E}
end
# Rebuild an `MvLocationScale` from the flat vector `flat`.
#
# With `half = length(flat) ÷ 2`, the first `half` entries become the location and the
# last `half` entries become the diagonal of a `Diagonal` scale. The base distribution
# and `scale_eps` are copied from the wrapped `re.model`.
function (re::RestructureMeanField)(flat::AbstractVector)
    half = length(flat) ÷ 2
    template = re.model
    loc = first(flat, half)
    scale_diag = last(flat, half)
    return MvLocationScale(loc, Diagonal(scale_diag), template.dist, template.scale_eps)
end
# Flatten a `Diagonal`-scale `MvLocationScale` into a single parameter vector.
#
# Returns a tuple `(flat, re)` where `flat` is the location followed by the diagonal of
# the scale, and `re` is a `RestructureMeanField` wrapping `q` that can be called on a
# flat vector to rebuild a distribution.
function Optimisers.destructure(
    q::MvLocationScale{<:Diagonal, D, L}
) where {D, L}
    flat = [q.location; diag(q.scale)]
    return flat, RestructureMeanField(q)
end
"""
    restructure_ad_forward(re, params)

Call `re(params)` and assert that the result has exactly the type of `re.model`.

A `TypeError` is thrown if the rebuilt object is of a different type.
"""
function restructure_ad_forward(re, params)
    rebuilt = re(params)
    return rebuilt::typeof(re.model)
end
"""
    f(params, aux)

Objective used for differentiation: rebuild a distribution from `params` with
`aux.restructure` (via `restructure_ad_forward`) and return the sum of its `location`.
"""
function f(params, aux)
    re = aux.restructure
    rebuilt = restructure_ad_forward(re, params)
    return sum(rebuilt.location)
end
"""
    main()

Driver for the reproduction: build a 10-dimensional `MvLocationScale` with zero location
and unit `Diagonal` scale, destructure it, display the primal value of `f`, then run
Enzyme reverse-mode `autodiff` of `f` at a vector of ones and return the gradient buffer.
"""
function main()
    n = 10
    q0 = MvLocationScale(zeros(n), Diagonal(ones(n)), Normal(), 1e-5)
    params, re = Optimisers.destructure(q0)
    aux = (restructure = re,)

    # Evaluate the objective once outside of Enzyme and show the result.
    display(f(params, aux))

    n_params = length(params)
    x = ones(n_params)
    ∇x = zeros(n_params)  # shadow buffer that Enzyme accumulates the gradient into

    Enzyme.API.runtimeActivity!(true)
    Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        f,
        Enzyme.Active,
        Enzyme.Duplicated(x, ∇x),
        Enzyme.Const(aux),
    )
    return ∇x
end
Hi, here we go again. I tried to resolve the previous type instability issue by forcing the return type. This, however, results in a build error. I have a nice MWE for this one.
The error is as follows: