Closed: Red-Portal closed this issue 3 weeks ago.
using Enzyme
using Distributions
using LinearAlgebra
using StableRNGs
using Random
using Test
"""
    ad_forward(scale_diag, rng)

Monte Carlo objective used by `mwe`. It works in three steps:

1. Draw a `length(scale_diag) × 10` matrix of standard-normal samples from `rng`,
   with element type `eltype(scale_diag)`.
2. Scale row `i` by `scale_diag[i]`. This assumes `scale_diag` is a vector.
3. Return the mean, over columns, of each column's sum.

The row count follows `length(scale_diag)` instead of a hard-coded `10`. The broadcast
therefore no longer throws `DimensionMismatch` for other dimensions. For the
length-10 input used by `mwe`, the draws and the result are unchanged.
"""
function ad_forward(scale_diag, rng)
    T = eltype(scale_diag)
    nsamples = 10  # number of draws, one per column
    # Rows must equal length(scale_diag) so that `.*` scales each coordinate.
    draws = rand(rng, Normal{T}(zero(T), one(T)), length(scale_diag), nsamples)
    samples = scale_diag .* draws
    # Variant from the original report (`logdensity` and `Problem` are not defined here):
    # res = mean(Base.Fix1(logdensity, Problem()), eachcol(samples))
    res = mean(sum, eachcol(samples))
    return res
end
"""
    mwe(T)

Differentiate `ad_forward` in reverse mode, with the primal, and return the gradient.

The input is a length-10 vector of ones with element type `T`. A freshly seeded
`StableRNG` is passed as a constant argument.
"""
function mwe(T)
    dim = 10
    rng = StableRNG(0x38bef07cf9cc549d)
    x = ones(T, dim)
    dx = make_zero(x)  # shadow of `x`; the reverse pass accumulates the gradient here
    _, primal = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(x, dx),
        Enzyme.Const(rng),
    )
    return dx
end
# Compute the gradient in Float32 and Float64. The two results are expected to
# agree to within 1% relative tolerance.
@show mwe(Float32)
@show mwe(Float64)
@test mwe(Float32) ≈ mwe(Float64) rtol=1e-2
wmoses@beast:~/git/Enzyme.jl ((HEAD detached at origin/main)) $ cat vi.jl
using Enzyme
using Test
using Statistics
# NOTE(review): debugging switch. Presumably makes Enzyme print its IR while it
# builds the derivative; confirm against the Enzyme.jl docs.
Enzyme.API.printall!(true)
"""
    ad_forward(scale_diag::Vector{T}) where T

Compute `T(1.41) * scale_diag[1]` and store it in a one-element `Vector{T}`.

The function then returns `mean(sum, eachcol(...))` of that vector, which is the
product itself. Indexing is `@inbounds`, so `scale_diag` must be non-empty.
"""
function ad_forward(scale_diag::Vector{T}) where T
    factor = T[T(1.41)]
    scaled = Vector{T}(undef, 1)
    @inbounds begin
        s = scale_diag[1]
        scaled[1] = factor[1] * s
    end
    return mean(sum, eachcol(scaled))
end
"""
    mwe(T)

Differentiate the single-argument `ad_forward` in reverse mode, with the primal,
at `ones(T, 1)`, and return the gradient buffer.
"""
function mwe(T)
    x = ones(T, 1)
    dx = make_zero(x)  # shadow of `x`; the reverse pass accumulates the gradient here
    derivs, primal = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(x, dx),
    )
    return dx
end
# Compare the Float32 and Float64 gradients of the single-element example. They
# should agree to within 1% relative tolerance.
@show mwe(Float32)
@show mwe(Float64)
@test mwe(Float32) ≈ mwe(Float64) rtol=1e-2
using Enzyme
using Test
using Statistics
# NOTE(review): debugging switches. They presumably make Enzyme print its IR and
# dump the post-optimization module (`DumpPostOpt` is a Ref set to `true`).
# Confirm against the Enzyme.jl docs.
Enzyme.API.printall!(true)
Enzyme.Compiler.DumpPostOpt[] = true
# Sum the elements of `A` left to right, starting from `eltype(A)(0)`.
# An empty `A` yields that starting value. `@noinline` keeps this a separate
# call, as in the original.
@noinline mysum(A) = foldl(+, A; init = eltype(A)(0))
# Convert `y` to `promote_type(T, S)`, where `T` and `S` are the types bound to the
# two arguments. Only the first argument's type is used.
function _mean_promote(::T, y::S) where {T,S}
    target = promote_type(T, S)
    return convert(target, y)
end
# Non-`const` global holding `nothing`. Because it is not `const`, its type is not
# known to the compiler. Returning it from `mymean` therefore makes that function's
# return type non-inferrable. This is presumably deliberate, to reproduce the
# issue, so keep it non-const.
unstable = nothing
# Reduce `f` applied to each element of `A` with `Base.add_sum`, using Base's internal
# `mapfoldl_impl`. If `c` is true, return the global `unstable` instead.
# Despite the name, the result is not divided by `length(A)`.
#
# NOTE(review): this looks like a hand-reduced copy of a `mean(f, A)` code path,
# kept to reproduce an Enzyme bug. Several oddities appear deliberate:
#   - the unused `x1`, which still calls `f` on `A[1]`;
#   - the unused type parameter `T`;
#   - the early return of an untyped global.
# Do not clean these up without re-running the repro. `@inbounds A[1]` assumes
# `A` is non-empty.
@inline function mymean(f, A, ::Type{T}, c) where T
c && return unstable
x1 = f(@inbounds A[1]) / 1
result = Base.mapfoldl_impl(f, Base.add_sum, Base._InitialValue(), A)
return result
end
"""
    sum2(x::Array{T}) where T

Left-fold `x` with `Base.add_sum`, starting from `T(0)`, through Base's internal
`foldl_impl` entry point. An empty `x` yields `T(0)`.
"""
function sum2(x::Array{T}) where {T}
    init = T(0)
    return Base.foldl_impl(Base.add_sum, init, x)
end
"""
    ad_forward(scale_diag::Vector{T}, c) where T

Print `scale_diag` through the runtime's `jl_` debug printer. Then wrap it in a
one-element vector and return `mymean(sum2, wrapped, T, c)`.
"""
function ad_forward(scale_diag::Vector{T}, c) where {T}
    ccall(:jl_, Cvoid, (Any,), scale_diag)  # low-level debug print of the argument
    wrapped = [scale_diag]
    return mymean(sum2, wrapped, T, c)
end
"""
    mwe(T)

Differentiate the two-argument `ad_forward` in reverse mode, with the primal, at
`ones(T, 1)`, and return the gradient buffer. The constant second argument is
`false`.
"""
function mwe(T)
    x = ones(T, 1)
    dx = make_zero(x)  # shadow of `x`; the reverse pass accumulates the gradient here
    derivs, primal = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(x, dx),
        Enzyme.Const(false),
    )
    return dx
end
# Only the Float32 case is exercised here. The Float64 run and the cross-precision
# comparison are left commented out.
@show mwe(Float32)
#@show mwe(Float64)
#@test mwe(Float32) ≈ mwe(Float64) rtol=1e-2
Hi,
It seems that gradients through `Base.Fix1` are ignored if `Float32` is involved. Here's an MWE:

This yields:
On the other hand, removing `Base.Fix1` as follows works: