EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
428 stars 59 forks source link

Error of custom type interaction with StaticArrays (type unstable update of an immutable variable) #1263

Open just-walk opened 5 months ago

just-walk commented 5 months ago

Enzyme fails when differentiating through a custom type that is based on SMatrix and labeled with abstract types. The error is discussed here, where @wsmoses suggested that this case is "not yet implemented" and that the workaround is to make the code type-stable.

MWE and its output are below:

using Enzyme
using Random 
using StaticArrays

# Abstract supertype for basis-kind tags; used below as the upper bound of the
# `V` type parameter of `CurvilinearBasisVectors`.
abstract type AbstractBasisType end

# Field-less concrete tag type; passed as the `V` parameter in this example.
struct Contravariant <: AbstractBasisType end

# N-by-N static matrix wrapper with element type `T`.
# `C` and `B` are not used inside the struct body; in this script they carry the
# `basis_labels` / `coord_labels` symbol tuples. `V` is a basis-kind tag type.
struct CurvilinearBasisVectors{N, T, C, B, V <: AbstractBasisType} <: StaticMatrix{N, N, T}
    # NOTE(review): `SMatrix{N, N, T}` leaves SMatrix's trailing length parameter
    # unspecified (the stack trace shows the concrete type as
    # `SMatrix{3, 3, Float64, 9}`), so this field is presumably not concretely
    # typed -- the type instability the issue title refers to. The single-member
    # `Union{...}` wrapper is kept exactly as in the original reproducer.
    __x::Union{SMatrix{N, N, T}}
    # Inner constructor: converts any `AbstractMatrix` to `SMatrix{N, N, T}` and
    # stores it in `__x`. All five type parameters must be given explicitly.
    function CurvilinearBasisVectors{N, T, C, B, V}(b::AbstractMatrix) where {N, T, C, B, V}
        return new{N, T, C, B, V}(SMatrix{N, N, T}(b))
    end
end

# Integer indexing for `CurvilinearBasisVectors`: reads the wrapped matrix with
# `getfield`, takes a `view` of it at index `i`, and dereferences that view
# with `[]` to return the element.
# This `view` -> `getindex` path is the one that appears in frames [2]-[4] of
# the reported stack trace.
Base.@propagate_inbounds function Base.getindex(v::CurvilinearBasisVectors{N, T, C, B, V}, i::Int) where {N, T, C, B, V}
    return view(getfield(v, :__x), i)[]
end

# Symbol tuples passed as the `C` and `B` type parameters below.
basis_labels = (:∇x, :∇y, :∇z);
coord_labels = (:x, :y, :z);

# Primal value: a 3x3 wrapper filled with random Float64 entries.
a = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    rand(3,3),
);
# Zero-initialised value of the same type; passed as the second argument of
# `Duplicated(a, da)` in the `autodiff` call further down.
da = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    zeros(3,3),
);

@show a,da

# Function being differentiated: broadcasts `y .* x` element-wise and sums the
# result. With the matrix `x` and scalar `y` used in this script, the inner
# `sum` is presumably already a scalar, so the outer `sum` leaves it unchanged.
function f(x, y)
    sum(sum(y .* x))
end

# The plain primal evaluation succeeds (see the printed `f(a, 2.0) = ...` line).
@show f(a, 2.0)

# Reverse-mode differentiation of `f`, with `a` wrapped as `Duplicated(a, da)`
# and the scalar as `Active(2.0)`. This is the call that raises the `setfield!`
# error reported below, so the final `@show da` is never reached.
@show autodiff(Reverse, f, Active, Duplicated(a, da), Active(2.0))
@show da
(a, da) = ([0.03087747827664955 0.2528585283866567 0.2994280949524749; 0.3884839505154861 0.6683809551000912 0.9092201295064689; 0.564563029791595 0.06953550568414368 0.8692541334330144], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0])
f(a, 2.0) = 8.105203611293161
ERROR: LoadError: setfield!: immutable struct of type SArray cannot be changed
Stacktrace:
  [1] rt_jl_getfield_rev(::SMatrix{3, 3, Float64, 9}, ::Base.RefValue{NTuple{9, Float64}}, ::Type{Val{:data}}, ::Val{false})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/rules/typeunstablerules.jl:257
  [2] getindex
    @ ~/.julia/packages/StaticArrays/eGKzB/src/SArray.jl:62 [inlined]
  [3] view
    @ ~/.julia/packages/StaticArrays/eGKzB/src/abstractarray.jl:291 [inlined]
  [4] getindex
    @ ~/software/julia/enzyme/test-basis.jl:17 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:135 [inlined]
  [6] __broadcast
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:123 [inlined]
  [7] _broadcast
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:119 [inlined]
  [8] copy
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:60 [inlined]
  [9] materialize
    @ ./broadcast.jl:903 [inlined]
 [10] f
    @ ~/software/julia/enzyme/test-basis.jl:33 [inlined]
 [11] f
    @ ~/software/julia/enzyme/test-basis.jl:0 [inlined]
 [12] diffejulia_f_5401_inner_1wrap
    @ ~/software/julia/enzyme/test-basis.jl:0
 [13] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:5306 [inlined]
 [14] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Active{…}, ::Float64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4984
 [15] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4926
 [16] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:215
 [17] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::Type, ::Duplicated{CurvilinearBasisVectors{…}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:224
 [18] macro expansion
    @ show.jl:1181 [inlined]
 [19] top-level scope
    @ ~/software/julia/enzyme/test-basis.jl:38
 [20] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [21] top-level scope
    @ REPL[4]:1
wsmoses commented 5 months ago

Marking this as a duplicate of https://github.com/EnzymeAD/Enzyme.jl/issues/970 [though this is an easier MWE]