Enzyme fails when differentiating a function that takes a custom type built on SMatrix and parameterized by abstract-type labels. The error is discussed here, where @wsmoses suggested that this case is "not-yet-implemented" and that type-stabilizing the code works around it.
The MWE and its output are below:
using Enzyme
using Random
using StaticArrays

abstract type AbstractBasisType end
struct Contravariant <: AbstractBasisType end

struct CurvilinearBasisVectors{N, T, C, B, V <: AbstractBasisType} <: StaticMatrix{N, N, T}
    __x::Union{SMatrix{N, N, T}}
    function CurvilinearBasisVectors{N, T, C, B, V}(b::AbstractMatrix) where {N, T, C, B, V}
        return new{N, T, C, B, V}(SMatrix{N, N, T}(b))
    end
end

Base.@propagate_inbounds function Base.getindex(v::CurvilinearBasisVectors{N, T, C, B, V}, i::Int) where {N, T, C, B, V}
    return view(getfield(v, :__x), i)[]
end

basis_labels = (:∇x, :∇y, :∇z);
coord_labels = (:x, :y, :z);

a = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    rand(3,3),
);
da = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    zeros(3,3),
);
@show a,da

function f(x, y)
    sum(sum(y .* x))
end

@show f(a, 2.0)
@show autodiff(Reverse, f, Active, Duplicated(a, da), Active(2.0))
@show da
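One possible reading of the "type-stabilize" workaround is sketched below. It assumes the instability comes from the field `__x`, whose declared type `SMatrix{N, N, T}` omits the length parameter and is therefore not concrete. The struct name `StableBasisVectors` and the extra type parameter `L` are hypothetical and not part of the original code, and this sketch has not been checked against the Enzyme error itself.

```julia
using StaticArrays

abstract type AbstractBasisType end
struct Contravariant <: AbstractBasisType end

# Hypothetical variant of the MWE type: the extra parameter L (= N*N)
# lets the field be declared with a fully concrete SMatrix type.
struct StableBasisVectors{N, T, C, B, V <: AbstractBasisType, L} <: StaticMatrix{N, N, T}
    __x::SMatrix{N, N, T, L}
    function StableBasisVectors{N, T, C, B, V}(b::AbstractMatrix) where {N, T, C, B, V <: AbstractBasisType}
        x = SMatrix{N, N, T}(b)                 # length parameter is filled in as N*N
        return new{N, T, C, B, V, N * N}(x)
    end
end

# Linear indexing forwards to the wrapped SMatrix.
Base.@propagate_inbounds Base.getindex(v::StableBasisVectors, i::Int) = getfield(v, :__x)[i]

basis_labels = (:∇x, :∇y, :∇z)
coord_labels = (:x, :y, :z)
a = StableBasisVectors{3, Float64, basis_labels, coord_labels, Contravariant}(rand(3, 3))

# The field type is now concrete, unlike SMatrix{N, N, T} in the MWE.
@show isconcretetype(fieldtype(typeof(a), :__x))
```

Callers still write only the first five type parameters; the inner constructor computes `L`. Whether this is enough for `autodiff` to succeed on `f` is something to verify separately.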