mhauru closed this issue 1 day ago.
Ah yeah, BigFloats aren't presently expected to work. Adding support for them is on our to-do list, but honestly it's not a high priority at the moment.
There shouldn't be any BigFloats coming into play in the MWE, though; it's all multivariate normals with Float64s.
The problem seems to come from the same type instability as #1608 and https://github.com/TuringLang/Turing.jl/issues/2240, or a closely related one. The `TV` variable in the above function has type `Matrix{Real}`, even though it should be inferrable as `Matrix{Float64}`, and if you fix it to `Matrix{Float64}` the issue goes away.
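A small self-contained sketch of the difference being described, using only Base Julia. It mimics the `TypeWrap` pattern from the MWE, but `alloc_state` is a hypothetical helper, and deriving the matrix type from `eltype(x)` is just one assumed way to get `Matrix{Float64}` instead of hard-coding `Matrix{Real}`; it is not necessarily how the real model code should choose the type.

```julia
# Hypothetical stand-in for the TypeWrap pattern in the MWE:
# the caller picks the matrix type, and the model allocates it.
struct TypeWrap{T} end

# Hypothetical helper: allocate a 2x2 matrix of the requested type
alloc_state(::TypeWrap{TV}) where {TV} = TV(undef, 2, 2)

x = [1.0, 1.0, 1.0, 1.0]

# Hard-coded abstract element type, as in the MWE's g(x)
state_abstract = alloc_state(TypeWrap{Matrix{Real}}())

# Element type derived from the input vector, which is Float64 here
state_concrete = alloc_state(TypeWrap{Matrix{eltype(x)}}())

println(eltype(state_abstract), " concrete? ", isconcretetype(eltype(state_abstract)))
println(eltype(state_concrete), " concrete? ", isconcretetype(eltype(state_concrete)))
```

Deriving the type from the input keeps the code generic: under ForwardDiff, for example, `eltype(x)` would be a `Dual` type rather than `Float64`, so the same function still works there.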
Hm, it definitely thinks there's a code path that could call a BigFloat, even if in practice that path is never taken.
If you can minimize this a bit more, I can work on making sure this error doesn't happen.
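One plausible reading of where a BigFloat path could come from, offered as an illustration rather than a description of Enzyme's internals: `BigFloat <: Real`, so a `Matrix{Real}` may legally hold a `BigFloat`, and code that reads its entries cannot rule that out from the types alone. A `Matrix{Float64}` has no such possibility. Only Base Julia is used here.

```julia
# A matrix with abstract element type Real may hold any Real, including BigFloat
m = Matrix{Real}(undef, 2, 2)
m[1, 1] = 1.0        # Float64
m[1, 2] = big"1.5"   # BigFloat is a subtype of Real, so this is allowed
m[2, 1] = 2          # Int
m[2, 2] = 3 // 4     # Rational{Int}

# Entries are stored as boxed values whose concrete types are only known at run time
println(map(typeof, m))

# With a concrete element type, assigning a BigFloat converts it to Float64
f = Matrix{Float64}(undef, 2, 2)
f[1, 2] = big"1.5"
println(typeof(f[1, 2]))
```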
An MWE whose only dependencies besides Enzyme are Accessors and Distributions:
```julia
module MWE

import Accessors
import Distributions
using Enzyme

#Enzyme.API.runtimeActivity!(true)

struct VarName{sym,T}
    optic::T
    function VarName{sym}(optic=identity) where {sym}
        return new{sym,typeof(optic)}(optic)
    end
end

function Base.:(==)(x::VarName{symx}, y::VarName{symy}) where {symx,symy}
    return x.optic == y.optic && symx == symy
end

struct TypeWrap{T} end

struct VarInfo{Tval,Tlogp}
    vals::Tval
    logp::Base.RefValue{Tlogp}
end

# Module-local getindex (not Base.getindex): maps each VarName to its slice of vi.vals
function getindex(vi::VarInfo, vn::VarName)
    range = vn == x_varname1 ? (1:2) : (3:4)
    return copy(vi.vals[range])
end

VarInfo(old_vi::VarInfo, x) = VarInfo(x, Base.RefValue{eltype(x)}(old_vi.logp[]))

function tilde_assume!!(right, vn, vi)
    y = [1.0, 1.0]
    _ = Distributions.logpdf(right, y)
    r = getindex(vi, vn)
    logp = 0.0
    vi.logp[] += logp
    return r, vi
end

x_varname1 = VarName{:x}((Accessors.@optic _[:, 1]))
x_varname2 = VarName{:x}((Accessors.@optic _[:, 2]))

function satellite_model_matrix(__varinfo__, ::TypeWrap{TV}) where {TV}
    P0 = vcat([0.1 0.0], [0.0 0.1])
    x = TV(undef, 2, 2)
    v1 = Distributions.MvNormal([0.0, 0.0], P0)
    v3, __varinfo__ = tilde_assume!!(v1, x_varname1, __varinfo__)
    x[:, 1] .= v3
    v1 = Distributions.MvNormal(x[:, 1], P0)
    v4, __varinfo__ = tilde_assume!!(v1, x_varname2, __varinfo__)
    x[:, 2] .= v4
    return nothing, __varinfo__
end

vi = VarInfo(
    [0.0, 0.0, 0.0, 0.0],
    Base.RefValue{Float64}(0.0),
)

function g(x)
    vi_new = VarInfo(vi, x)
    # Matrix{Real} triggers the error; with Matrix{Float64} it goes away
    _, wrapper_new = satellite_model_matrix(vi_new, TypeWrap{Matrix{Real}}())
    return wrapper_new.logp[]
end

x = [1.0, 1.0, 1.0, 1.0]
Enzyme.autodiff(ReverseWithPrimal, g, Active, Enzyme.Duplicated(x, zero(x)))

# using ForwardDiff
# @show ForwardDiff.gradient(g, x)

end
```
Output:
This happens on 0.0.132 and on the latest Enzyme.jl main.
I appreciate that's not very minimal, but I'm done for the day and wanted to put this up here. I can try to minimise further tomorrow if that's helpful.
This started life as a reproduction of #1608, but after the latest Enzyme update it stopped producing the error from #1608 and started producing the output above instead.