EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Write analysis pass on `autodiff` vs `autodiff_thunk` #1854

Closed: gdalle closed this issue 4 days ago

gdalle commented 4 days ago

Recently (#1669), an analysis pass was added to `autodiff` so that, when the function `f` is passed without an annotation, any attempt to write into `f` throws an error. This behavior is controlled by the `ErrIfFuncWritten` type parameter of the mode.

I'm wondering whether `autodiff_thunk` could get the same behavior by default. The problem is that, at the moment, `autodiff_thunk` does not accept unannotated functions at all.

gdalle commented 4 days ago

MWE:

```julia
julia> using Enzyme

julia> struct MyClosure{D}
           data::D
       end

julia> function (f::MyClosure)(x)
           copyto!(f.data, x)
           return sum(f.data)
       end

julia> f = MyClosure([0.0])
MyClosure{Vector{Float64}}([0.0])

julia> autodiff(Reverse, f, Active, Duplicated([1.0], [0.0]))
ERROR: Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)
See https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.
The potentially writing call is   call void @llvm.memmove.p0i8.p0i8.i64(i8* nonnull align 1 %arrayptr, i8* nonnull align 1 %arrayptr10, i64 %31, i1 noundef false), !dbg !121, !noalias !126, using   %arrayptr = load i8*, i8** %34, align 8, !dbg !94, !tbaa !124, !alias.scope !63, !noalias !64, !nonnull !0

Stacktrace:
  [1] memmove
    @ ./cmem.jl:26 [inlined]
  [2] unsafe_copyto!
    @ ./array.jl:337 [inlined]
  [3] _copyto_impl!
    @ ./array.jl:376
  [4] copyto!
    @ ./array.jl:368 [inlined]
  [5] copyto!
    @ ./array.jl:388 [inlined]
  [6] MyClosure
    @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:8 [inlined]
  [7] MyClosure
    @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:0 [inlined]
  [8] diffejulia_MyClosure_4749_inner_1wrap
    @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
 [11] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6671 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:320 [inlined]
 [13] autodiff(mode::ReverseMode{…}, f::MyClosure{…}, ::Type{…}, args::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:332
 [14] top-level scope
    @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:14
Some type information was truncated. Use `show(err)` to see complete types.
```

```julia
julia> autodiff_thunk(ReverseSplitNoPrimal, typeof(sum), Active, Duplicated{Float64})
ERROR: MethodError: no method matching autodiff_thunk(::EnzymeCore.ReverseModeSplit{…}, ::Type{…}, ::Type{…}, ::Type{…})

Closest candidates are:
  autodiff_thunk(::EnzymeCore.ReverseModeSplit{ReturnPrimal, ReturnShadow, Width, ModifiedBetweenT, RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, ::Type{<:Annotation}...) where {FA<:Annotation, A<:Annotation, ReturnPrimal, ReturnShadow, Width, ModifiedBetweenT, RABI<:EnzymeCore.ABI, Nargs, ErrIfFuncWritten}
   @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:615
  autodiff_thunk(::ForwardMode{RABI, ErrIfFuncWritten}, ::Type{FA}, ::Type{A}, ::Type{<:Annotation}...) where {FA<:Annotation, A<:Annotation, RABI<:EnzymeCore.ABI, Nargs, ErrIfFuncWritten}
   @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:690

Stacktrace:
 [1] top-level scope
   @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:16
Some type information was truncated. Use `show(err)` to see complete types.
```

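For reference, a minimal sketch of how the two MWE calls could be annotated explicitly. It assumes the closure should be `Duplicated` rather than `Const`: its buffer `data` holds values copied from `x`, so a `Const(f)` annotation would likely lose that derivative path. The thunk part follows the split-mode calling convention shown in the Enzyme documentation (forward pass returns a tape, reverse pass takes the return seed and the tape); the shadow names `df`, `dx`, `df2`, `dx2` are illustrative, and the sketch has not been checked against a specific Enzyme version.

```julia
using Enzyme

struct MyClosure{D}
    data::D
end

function (f::MyClosure)(x)
    copyto!(f.data, x)  # writes x into the closure's own buffer
    return sum(f.data)
end

x = [1.0]

# Plain autodiff: annotate the closure itself as Duplicated, with a zeroed shadow buffer.
f = MyClosure([0.0])
df = MyClosure([0.0])
dx = [0.0]
autodiff(Reverse, Duplicated(f, df), Active, Duplicated(x, dx))
# dx should now be [1.0], since the function returns sum(x)

# autodiff_thunk takes annotation *types*, including one for the function itself.
fwd, rev = autodiff_thunk(ReverseSplitNoPrimal, Duplicated{typeof(f)}, Active, Duplicated{typeof(x)})
df2 = MyClosure([0.0])
dx2 = [0.0]
tape, _, _ = fwd(Duplicated(f, df2), Duplicated(x, dx2))  # augmented forward pass
rev(Duplicated(f, df2), Duplicated(x, dx2), 1.0, tape)     # reverse pass, seed 1.0 for the Active return
```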
wsmoses commented 4 days ago

Can you just use the error setters in the mode?

`autodiff_thunk` is a lower-level API, so strictly requiring an activity annotation on every argument (including the function) is intended.

gdalle commented 4 days ago

Sure, I just wanted to make sure this workaround was needed. I'll make these setters public and then use them.
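A minimal sketch of the suggested workaround, under stated assumptions: it assumes a mode setter named `set_err_if_func_written` defined in EnzymeCore and reachable as `Enzyme.EnzymeCore.set_err_if_func_written` (the exact name, module, and visibility may differ between versions). The function is passed to `autodiff_thunk` as `Const`, and the setter enables the `ErrIfFuncWritten` check on the mode, roughly mirroring what `autodiff` does for unannotated functions. `g` is a hypothetical read-only test function.

```julia
using Enzyme

# Hypothetical test function that only reads its input.
g(x) = sum(abs2, x)

# Assumed setter: returns the same split mode with ErrIfFuncWritten enabled,
# so a write into the Const-annotated function should raise an error.
mode = Enzyme.EnzymeCore.set_err_if_func_written(ReverseSplitNoPrimal)

fwd, rev = autodiff_thunk(mode, Const{typeof(g)}, Active, Duplicated{Vector{Float64}})

x = [1.0, 2.0]
dx = zeros(2)
tape, _, _ = fwd(Const(g), Duplicated(x, dx))  # augmented forward pass
rev(Const(g), Duplicated(x, dx), 1.0, tape)     # reverse pass, seed 1.0
# dx should now be 2 .* x, i.e. [2.0, 4.0]
```

Swapping `g` for a writing closure such as `MyClosure` should then fail loudly instead of silently dropping derivatives, which is the behavior the issue asks for.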