tpapp / TransformVariables.jl

Transformations to constrained variables from ℝⁿ.
Other
66 stars 14 forks source link

long as(::NamedTuple) fails with Enzyme #119

Open tpapp opened 8 months ago

tpapp commented 8 months ago

This is https://github.com/EnzymeAD/Enzyme.jl/issues/1235 in the wild.

using TransformVariables, Enzyme, StaticArrays

# Reproducer setup: build one large NamedTuple transformation (33 fields in total).
# K is the length parameter of the SVector-valued ω_* fields below.
K = 7
# Length argument shared by every `as(view, …)` field of the EE/UU/UE/EU groups.
N_EE = N_UU = N_UE = N_EU = 20
# NOTE(review): the failure is reported for this "long" NamedTuple; whether a shorter
# field list also triggers it is not shown here, so keep the field list intact.
trans = as((# common
            ω_intercept = as(SVector{K}), ω_std = as(SVector{K}, asℝ₊),
            ω_corr_factor = as(SVector{K}), ζ_std = asℝ₊, ε_std = asℝ₊,
            κ = as(Real, 0.5, 1.5),
            B1 = as(SVector{3}), B2 = as(SVector{3}), BC = as(SVector{3}),
            # EE
            α̂1_EE = as(view, N_EE), α̂2_EE = as(view, N_EE),
            β̂1_EE = as(view, N_EE), β̂2_EE = as(view, N_EE),
            M̂_EE = as(view, N_EE),
            # EU
            α̂1_EU = as(view, N_EU), α̂2_EU = as(view, N_EU),
            β̂1_EU = as(view, N_EU), β̂2_EU = as(view, N_EU),
            M̂_EU = as(view, N_EU), ŵ2_EU = as(view, N_EU),
            # UE
            α̂1_UE = as(view, N_UE), α̂2_UE = as(view, N_UE),
            β̂1_UE = as(view, N_UE), β̂2_UE = as(view, N_UE),
            M̂_UE = as(view, N_UE), ŵ1_UE = as(view, N_UE),
            # UU
            α̂1_UU = as(view, N_UU), α̂2_UU = as(view, N_UU),
            β̂1_UU = as(view, N_UU), β̂2_UU = as(view, N_UU),
            M̂_UU = as(view, N_UU),
            ŵ1_UU = as(view, N_UU), ŵ2_UU = as(view, N_UU),
        ))

# Testing helper: recursively add up every real number nested inside `x`.

# Base case: a real number is its own sum.
function _s(x::Real)
    return x
end

# An array contributes the total of the recursive sums of its elements.
function _s(x::AbstractArray)
    return sum(_s, x)
end

# A named tuple contributes the total of the recursive sums of its field values.
function _s(x::NamedTuple)
    return sum(_s, values(x))
end

# Scalar objective: transform the flat vector `x` with `t`, then reduce the result
# to a single number with the recursive sum `_s`.
g(t, x) = _s(transform(t, x))
# All-zero input vector whose length is `dimension(trans)`.
x = zeros(dimension(trans))
g(trans, x)                     # sanity check that primal call works
# Zero array shaped like `x`, passed below as the second argument of `Enzyme.Duplicated`.
# NOTE(review): presumably this buffer receives the gradient w.r.t. `x` — confirm
# against the Enzyme documentation.
∂ℓ_∂x = zero(x)
# Differentiate `g` with `trans` wrapped in `Const` and `x` in `Duplicated`, using
# `ReverseWithPrimal`; this is the call reported to fail.
_, y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, g,
                       Enzyme.Active, Enzyme.Const(trans), Enzyme.Duplicated(x, ∂ℓ_∂x))

fails with

ERROR: AssertionError: Found unhandled active variable in tuple splat, jl_apply_iterate @NamedTuple{ω_intercept::TransformVariables.StaticArrayTransformation{7, Tuple{7}, TransformVariables.Identity}, ω_std::TransformVariables.StaticArrayTransformation{7, Tuple{7}, TransformVariables.ShiftedExp{true, Int64}}, ω_corr_factor::TransformVariables.StaticArrayTransformation{7, Tuple{7}, TransformVariables.Identity}, ζ_std::TransformVariables.ShiftedExp{true, Int64}, ε_std::TransformVariables.ShiftedExp{true, Int64}, κ::TransformVariables.ScaledShiftedLogistic{Float64}, B1::TransformVariables.StaticArrayTransformation{3, Tuple{3}, TransformVariables.Identity}, B2::TransformVariables.StaticArrayTransformation{3, Tuple{3}, TransformVariables.Identity}, BC::TransformVariables.StaticArrayTransformation{3, Tuple{3}, TransformVariables.Identity}, α̂1_EE::TransformVariables.ViewTransformation{1}, α̂2_EE::TransformVariables.ViewTransformation{1}, β̂1_EE::TransformVariables.ViewTransformation{1}, β̂2_EE::TransformVariables.ViewTransformation{1}, M̂_EE::TransformVariables.ViewTransformation{1}, α̂1_EU::TransformVariables.ViewTransformation{1}, α̂2_EU::TransformVariables.ViewTransformation{1}, β̂1_EU::TransformVariables.ViewTransformation{1}, β̂2_EU::TransformVariables.ViewTransformation{1}, M̂_EU::TransformVariables.ViewTransformation{1}, ŵ2_EU::TransformVariables.ViewTransformation{1}, α̂1_UE::TransformVariables.ViewTransformation{1}, α̂2_UE::TransformVariables.ViewTransformation{1}, β̂1_UE::TransformVariables.ViewTransformation{1}, β̂2_UE::TransformVariables.ViewTransformation{1}, M̂_UE::TransformVariables.ViewTransformation{1}, ŵ1_UE::TransformVariables.ViewTransformation{1}, α̂1_UU::TransformVariables.ViewTransformation{1}, α̂2_UU::TransformVariables.ViewTransformation{1}, β̂1_UU::TransformVariables.ViewTransformation{1}, β̂2_UU::TransformVariables.ViewTransformation{1}, M̂_UU::TransformVariables.ViewTransformation{1}, ŵ1_UU::TransformVariables.ViewTransformation{1}, 
ŵ2_UU::TransformVariables.ViewTransformation{1}}
Stacktrace:
  [1] error_if_active_iter(arg::Base.RefValue{@NamedTuple{…}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/rules/jitrules.jl:775
  [2] Tuple
    @ ./namedtuple.jl:200 [inlined]
  [3] values
    @ ./namedtuple.jl:379 [inlined]
  [4] transform_with
    @ ~/code/julia/TransformVariables/src/aggregation.jl:388 [inlined]
  [5] transform
    @ ~/code/julia/TransformVariables/src/generic.jl:268
  [6] g
    @ ./REPL[31]:1 [inlined]
  [7] g
    @ ./REPL[31]:0 [inlined]
  [8] augmented_julia_g_3346_inner_1wrap
    @ ./REPL[31]:0
  [9] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:5299 [inlined]
 [10] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:4977
 [11] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:4930
 [12] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/Dd2LU/src/Enzyme.jl:198
 [13] autodiff(::ReverseMode{true, FFIABI}, ::typeof(g), ::Type, ::Const{TransformVariables.TransformTuple{…}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/Dd2LU/src/Enzyme.jl:224
 [14] top-level scope
    @ REPL[35]:1