EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Enzyme doesn't work for `AdvancedVI` Part VIII: FP32 + `Base.Fix1` ignores gradients #1735

Closed: Red-Portal closed this issue 3 weeks ago

Red-Portal commented 1 month ago

Hi,

It seems that gradients through `Base.Fix1` are ignored when `Float32` is involved. Here's an MWE:


using Enzyme
using Distributions
using SimpleUnPack
using LinearAlgebra
using StableRNGs
using Random
using Optimisers
using Functors
using Test

struct MvLocationScale{
    S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
    scale_eps::E
end

Functors.@functor MvLocationScale (location, scale)

function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    @unpack location, scale, dist = q
    n_dims = length(location)
    scale*rand(rng, dist, n_dims, num_samples) .+ location
end

# This specialization improves AD performance of the sampling path
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
    @unpack location, scale, dist = q
    n_dims     = length(location)
    scale_diag = diag(scale)
    scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
end

struct Problem end

function logdensity(problem, params)
    sum(params)
end

function ad_forward(params, aux)
    @unpack rng, problem, restructure = aux
    q = restructure(params)
    samples = rand(rng, q, 10)
    mean(Base.Fix1(logdensity, problem), eachcol(samples))
end

function mwe(T)
    d    = 10
    seed = (0x38bef07cf9cc549d)
    rng  = StableRNG(seed)

    q = MvLocationScale(
        zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)), T(1e-5)
    )
    params, restructure, = Optimisers.destructure(q)

    aux = (
        rng         = rng,
        problem     = Problem(),
        restructure = restructure,
    )

    Enzyme.API.runtimeActivity!(true)
    ∇x = zero(params)
    fill!(∇x, zero(eltype(∇x)))
    _, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(params, ∇x),
        Enzyme.Const(aux)
    )
    ∇x
end

@test mwe(Float32) ≈ mwe(Float64) rtol=1e-2

This yields:

Test Failed at REPL[19]:1
  Expression: ≈(mwe(Float32), mwe(Float64), rtol = 0.01)
   Evaluated: Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ≈ [0.9999999999999999, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.22300920676562927] (rtol=0.01)

On the other hand, the following variant, which removes `Base.Fix1`, works:

function logdensity(params)
    sum(params)
end

function ad_forward(params, aux)
    @unpack rng, problem, restructure = aux
    q = restructure(params)
    samples = rand(rng, q, 10)
    mean(logdensity, eachcol(samples))
end
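
For context, `Base.Fix1(logdensity, problem)` is just a callable that partially applies the first argument, so `Base.Fix1(f, x)(y)` is the same as `f(x, y)` and the failing call is semantically equivalent to an anonymous closure over `problem`. Below is a minimal sketch of that closure-based variant of `ad_forward`, using the original two-argument `logdensity`; the name `ad_forward_closure` is made up here for illustration, and whether it also triggers the bug has not been checked:

# Semantically the same as the Base.Fix1 version above, but captures `problem`
# in an anonymous closure instead; assumes the definitions from the first MWE.
function ad_forward_closure(params, aux)
    @unpack rng, problem, restructure = aux
    q = restructure(params)
    samples = rand(rng, q, 10)
    mean(x -> logdensity(problem, x), eachcol(samples))
end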

wsmoses commented 4 weeks ago

using Enzyme

using Distributions
using LinearAlgebra
using StableRNGs
using Random
using Test

# Reduced reproducer: samples directly from a Normal and scales by scale_diag,
# dropping the MvLocationScale / Optimisers machinery from the original MWE.
function ad_forward(scale_diag, rng)

    T = eltype(scale_diag)
    samples = scale_diag.*rand(rng, Normal{T}(zero(T), one(T)), 10, 10)

    res = mean(sum, eachcol(samples))
    #res = mean(Base.Fix1(logdensity, Problem()), eachcol(samples))
    return res
end

function mwe(T)
    d    = 10
    seed = (0x38bef07cf9cc549d)
    rng  = StableRNG(seed)

    q = ones(T, d)

    ∇x = make_zero(q)
    _, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(q, ∇x),
        Enzyme.Const(rng)
    )
    ∇x
end

@show mwe(Float32)
@show mwe(Float64)
@test mwe(Float32) ≈ mwe(Float64) rtol=1e-2

wsmoses commented 4 weeks ago

wmoses@beast:~/git/Enzyme.jl ((HEAD detached at origin/main)) $ cat vi.jl 
using Enzyme
using Test
using Statistics

Enzyme.API.printall!(true)

# Further reduction: a single scalar multiply written into a one-element Vector,
# still funneled through mean(sum, eachcol(...)).
function ad_forward(scale_diag::Vector{T}) where T
    t1 = T[T(1.41)] 
    samples = Vector{T}(undef, 1)
    sd = @inbounds scale_diag[1]
    @inbounds samples[1] = @inbounds t1[1] * sd
    res = mean(sum, eachcol(samples))
    return res
end

function mwe(T)
    q = ones(T, 1)

    ∇x = make_zero(q)
    res, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(q, ∇x),
    )
    ∇x
end

@show mwe(Float32)
@show mwe(Float64)
@test mwe(Float32) ≈ mwe(Float64) rtol=1e-2

wsmoses commented 3 weeks ago

using Enzyme
using Test
using Statistics

Enzyme.API.printall!(true)
Enzyme.Compiler.DumpPostOpt[] = true

# Final reduction: hand-copied pieces of the Statistics/Base reduction internals.
@noinline function mysum(A)
    m = eltype(A)(0)
    for a in A
        m += a
    end
    return m
end

_mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y)

unstable = nothing

@inline function mymean(f, A, ::Type{T}, c) where T
    c && return unstable
    x1 = f(@inbounds A[1]) / 1
    result = Base.mapfoldl_impl(f, Base.add_sum, Base._InitialValue(), A)
    return result 
end

function sum2(x::Array{T}) where T
    return Base.foldl_impl(Base.add_sum, T(0), x)
    #return Base._foldl_impl(Base.add_sum, T(0), x)
end

function ad_forward(scale_diag::Vector{T}, c) where T 
    ccall(:jl_, Cvoid, (Any,), scale_diag) # debug-print the argument via the runtime
    res = mymean(sum2, [scale_diag,], T, c)
    return res
end

function mwe(T)
    q = ones(T, 1)

    ∇x = make_zero(q)
    res, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(q, ∇x),
        Enzyme.Const(false),
    )
    ∇x
end

@show mwe(Float32)
#@show mwe(Float64)
#@test mwe(Float32) ≈ mwe(Float64) rtol=1e-2
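
For reference, reading this last reduction literally: with `c = false` and a one-element outer array, `mymean(sum2, [scale_diag,], T, c)` boils down to `sum2(scale_diag)`, i.e. `T(0) + scale_diag[1]`, so for `q = ones(T, 1)` the primal is one and the mathematically expected gradient is `[one(T)]` in either precision. A small sketch of that expectation (the `expected` helper is invented here purely for illustration):

# The reduced reproducer's primal is effectively scale_diag[1], so its gradient
# with respect to scale_diag should be a single one, independent of precision.
expected(::Type{T}) where {T} = [one(T)]
# mwe(Float64) is expected to match expected(Float64); whether mwe(Float32) does
# is exactly what this issue tracks.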

wsmoses commented 3 weeks ago

Fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1740