EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Pseudo-zero initialize non-differential values for sret #347

SBuercklin closed this issue 2 years ago

SBuercklin commented 2 years ago

AD with respect to a struct that has a non-differentiable field fails

using Enzyme

struct Foo2{X,Y}
    x::X
    y::Y
end

test_f(f::Foo2) = f.x^2

julia> autodiff(test_f, Active(Foo2(3.0, 2.0)))
(Foo2{Float64, Float64}(6.0, 0.0),)

julia> autodiff(test_f, Active(Foo2(3.0, :two)))
ERROR: UndefRefError: access to undefined reference
Stacktrace:
 [1] getproperty
   @ ./Base.jl:42 [inlined]
 [2] getindex
   @ ./refvalue.jl:56 [inlined]
 [3] macro expansion
   @ ~/.julia/packages/Enzyme/fgrLi/src/compiler.jl:3641 [inlined]
 [4] enzyme_call
   @ ~/.julia/packages/Enzyme/fgrLi/src/compiler.jl:3434 [inlined]
 [5] CombinedAdjointThunk
   @ ~/.julia/packages/Enzyme/fgrLi/src/compiler.jl:3416 [inlined]
 [6] autodiff
   @ ~/.julia/packages/Enzyme/fgrLi/src/Enzyme.jl:202 [inlined]
 [7] autodiff(f::typeof(test_f), args::Active{Foo2{Float64, Symbol}})
   @ Enzyme ~/.julia/packages/Enzyme/fgrLi/src/Enzyme.jl:227
 [8] top-level scope
   @ REPL[12]:1

This is on Julia 1.7.3 with both Enzyme@0.9.6 and the current Enzyme#main.

vchuravy commented 2 years ago

So this works in forward mode, but not in reverse mode :/ The issue is that we want to zero-initialize this struct, which is not valid because of the Symbol field (at least that's my guess).

julia> autodiff(Forward, test_f, Duplicated(Foo2(3.0, :two), Foo2(1.0, :two)))
(6.0,)

wsmoses commented 2 years ago

@vchuravy that is indeed the cause. Looking at the IR (below), what happens is the following:

res = { 0.0, null }              # zero-initialized shadow of f
res[1] += 2 * f.x * differeturn  # adjoint of f.x^2

A null Symbol is what triggers the error. Since the Symbol isn't initialized in the function, there's no shadow store that Enzyme could reuse for it here. There are several potential remedies, but I think the immediate one is to say that this is an improper use of Active, as opposed to Duplicated.

define internal { { double, {} addrspace(10)* } } @diffejulia_test_f_1624_.1({ double, {} addrspace(10)* } %0, double %differeturn) local_unnamed_addr #2 !dbg !22 {
entry:
  %"'de" = alloca double, align 8
  store double 0.000000e+00, double* %"'de", align 8
  %".fca.0.extract'de" = alloca double, align 8
  store double 0.000000e+00, double* %".fca.0.extract'de", align 8
  %"'de2" = alloca { double, {} addrspace(10)* }, align 8
  store { double, {} addrspace(10)* } zeroinitializer, { double, {} addrspace(10)* }* %"'de2", align 8
  %.fca.0.extract = extractvalue { double, {} addrspace(10)* } %0, 0, !dbg !23
  br label %invertentry, !dbg !23

invertentry:                                      ; preds = %entry
  store double %differeturn, double* %"'de", align 8
  %1 = load double, double* %"'de", align 8
  %m0diffe.fca.0.extract = fmul fast double %1, %.fca.0.extract
  %m1diffe.fca.0.extract = fmul fast double %1, %.fca.0.extract
  store double 0.000000e+00, double* %"'de", align 8
  %2 = load double, double* %".fca.0.extract'de", align 8
  %3 = fadd fast double %2, %m0diffe.fca.0.extract
  store double %3, double* %".fca.0.extract'de", align 8
  %4 = load double, double* %".fca.0.extract'de", align 8
  %5 = fadd fast double %4, %m1diffe.fca.0.extract
  store double %5, double* %".fca.0.extract'de", align 8
  %6 = load double, double* %".fca.0.extract'de", align 8
  %7 = getelementptr inbounds { double, {} addrspace(10)* }, { double, {} addrspace(10)* }* %"'de2", i32 0, i32 0
  %8 = load double, double* %7, align 8
  %9 = fadd fast double %8, %6
  store double %9, double* %7, align 8
  store double 0.000000e+00, double* %".fca.0.extract'de", align 8
  %10 = load { double, {} addrspace(10)* }, { double, {} addrspace(10)* }* %"'de2", align 8
  %11 = insertvalue { { double, {} addrspace(10)* } } undef, { double, {} addrspace(10)* } %10, 0
  ret { { double, {} addrspace(10)* } } %11
}

vchuravy commented 2 years ago

Yeah, but we can't pass in Duplicated.

Could we call a zero function instead of doing a zero allocation?

wsmoses commented 2 years ago

I'm not sure that would be a good resolution.

Unless we tell the compiler to preserve a lot more information, this is just an arbitrary register that gets its shadow allocated without any knowledge of what it is.

We may be able to get away with it here, since technically the only thing that matters is the argument type (which produces the returned shadows), and we do have the Julia version of that.

Still, I think it's possible to construct a case where some other part of the zeroing procedure ends up accessible in the return.

wsmoses commented 2 years ago

You can always wrap something in a Ref as a shim, which allows it to be Duplicated.

vchuravy commented 2 years ago

Right, my goal would be to allow the user to define what it means to allocate a zero. Maybe the other question is why we are even accessing that field.

> You can always wrap something in a ref as a shim, allowing it to be duplicated.

We should deal with Julia struct semantics correctly. Library writers will apply Enzyme to code without necessarily knowing what input the user provides. The split between Duplicated/Active is already problematic enough.

wsmoses commented 2 years ago

The reason is that we return it as part of the register, so it must have some value. More specifically, we always zero-initialize all shadow memory, which is why that field was "accessed". Or do you mean why it was accessed after AD?

Regardless, I think we should probably force a type error if you ever have Active{T} for a T that isn't legal to be active (and emit a warning suggesting to make it Duplicated instead, which may require wrapping it in a Ref).

wsmoses commented 2 years ago

@vchuravy with the PR, this now turns into a segfault:

julia> # 2. My function
       let
           grad_storage = similar(params0)
           Enzyme.gradient!(Enzyme.Reverse, grad_storage, objective, params0)
           @show grad_storage
       end

signal (11): Segmentation fault
in expression starting at REPL[80]:2
jl_is_type at /buildworker/worker/package_linux64/build/src/julia.h:1210 [inlined]
jl_f_issubtype at /buildworker/worker/package_linux64/build/src/builtins.c:490
unknown function (ip: 0x7f25c40e3436)
unknown function (ip: 0x7f25c40e439f)
Allocations: 158796627 (Pool: 158722478; Big: 74149); GC: 185
Segmentation fault (core dumped)

vchuravy commented 2 years ago

Which PR?

wsmoses commented 2 years ago

https://github.com/EnzymeAD/Enzyme.jl/pull/408

vchuravy commented 2 years ago

@wsmoses do you still have the full MWE? The snippet in https://github.com/EnzymeAD/Enzyme.jl/issues/347#issuecomment-1212435054 seems incomplete.

vchuravy commented 2 years ago

So the issue here is not that the Symbol is undef; that would be fine. The issue is that the Box we use for the sret is undef.

julia> Enzyme.Compiler.Box{Enzyme.Compiler.AnonymousStruct(Tuple{Float64, Symbol})}()
Enzyme.Compiler.Box{NamedTuple{(Symbol("1"), Symbol("2")), Tuple{Float64, Symbol}}}(#undef)

vchuravy commented 2 years ago

Digging a bit deeper: the issue is that we check the validity of the sret struct with Julia's isdefined check.

This ends up in https://github.com/JuliaLang/julia/blob/b9b60fcde61ff18d77cb548421b3f71a369b4e02/src/datatype.c#L1789, where the value of the first pointer field is used to make that decision.

We can't even sidestep this:

val = Base.unsafe_load(Base.unsafe_convert(Ptr{eltype(b)},  b));
Unhandled Task ERROR: UndefRefError: access to undefined reference

> I think regardless we should probably force a type error if you ever have Active{T} for any T which isn't legal to be active.

That doesn't work. We could have an sret that is Any or Real, and if we don't initialize that pointer value we can't load from the sret box. In Julia semantics a field can be undef, but a value can't.
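A minimal pure-Julia sketch of the isdefined behavior described above, for readers who want to reproduce it without Enzyme. `SretBox` is a hypothetical stand-in for `Enzyme.Compiler.Box`, not Enzyme's actual type, and the commented results assume Julia 1.7 or later.

```julia
# Hypothetical stand-in for an sret box (NOT Enzyme.Compiler.Box): a mutable
# struct whose only field is left uninitialized by the inner constructor.
mutable struct SretBox{T}
    val::T
    SretBox{T}() where {T} = new{T}()
end

# All-bits payload: isdefined is always true, even though nothing was stored.
bits_box = SretBox{Tuple{Float64,Float64}}()
bits_defined = isdefined(bits_box, :val)
println(bits_defined)      # true

# Payload containing a reference: the field only counts as defined once its
# first pointer (the Symbol) is non-null, so a fresh box reports false.
ptr_box = SretBox{Tuple{Float64,Symbol}}()
fresh_defined = isdefined(ptr_box, :val)
println(fresh_defined)     # false

# Reading the field throws, mirroring the UndefRefError in the report.
threw_undef = try
    ptr_box.val
    false
catch err
    err isa UndefRefError
end
println(threw_undef)       # true

# Storing any concrete value (some Symbol) makes the field defined.
ptr_box.val = (0.0, :two)
println(isdefined(ptr_box, :val))  # true
```

This is why simply zero-filling the box is not enough when the payload contains a reference: the zeroed Symbol slot is indistinguishable from "never assigned".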

wsmoses commented 2 years ago

I mean, Any and Real should also be Duplicated. Essentially, anything that isn't a floating-point register, an integer, or a structure containing only those should not be Active.
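One way to encode the rule above as a type check is sketched below. `legal_active` is a hypothetical helper, not part of Enzyme's API; treating `Bool` as an integer and rejecting zero-field bits types such as `Char` are assumptions of this sketch.

```julia
# Hypothetical check (not Enzyme API) for "may this type be Active?":
# floating-point or integer scalars, or immutable structs/tuples built
# recursively from such scalars.
function legal_active(T::Type)
    # Scalars: plain floating-point and integer bits types (Bool included).
    if T <: Union{AbstractFloat,Integer}
        return isbitstype(T)
    end
    # Abstract types such as Any or Real could hold anything at runtime.
    isconcretetype(T) || return false
    # isbitstype rules out mutable structs and any field holding a reference
    # (Symbol, String, arrays, ...).
    isbitstype(T) || return false
    n = fieldcount(T)
    n > 0 || return false   # e.g. Char, Ptr, empty structs
    return all(legal_active(fieldtype(T, i)) for i in 1:n)
end

struct Foo2{X,Y}
    x::X
    y::Y
end

legal_active(Float64)                # true
legal_active(Foo2{Float64,Float64})  # true
legal_active(Foo2{Float64,Symbol})   # false: the Symbol field is a reference
legal_active(Real)                   # false
legal_active(Any)                    # false
```

Under this rule `Foo2{Float64,Symbol}` is rejected outright, which is exactly the case from the original report; it would have to be passed as Duplicated (possibly wrapped in a Ref) instead.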

vchuravy commented 2 years ago

@wsmoses from our last chat, the remaining thing to do is to "pseudo-zero" the shadow sret?
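The thread does not spell out what "pseudo-zero" means. Assuming it means zeroing the floating-point parts of the shadow while giving every non-differentiable field some valid value (here, a copy of the primal's), a sketch could look like the following. `pseudo_zero` is hypothetical and not necessarily what Enzyme implemented; it also assumes each struct type has its default field-order constructor.

```julia
# Hypothetical helper (not Enzyme API): one reading of "pseudo-zero".
# Floating-point leaves become zero; every other leaf is copied from the
# primal, so reference fields such as a Symbol always hold a valid value.
pseudo_zero(x::AbstractFloat) = zero(x)
pseudo_zero(x::Tuple) = map(pseudo_zero, x)
pseudo_zero(x::NamedTuple) = map(pseudo_zero, x)

function pseudo_zero(x::T) where {T}
    # Rebuild immutable structs field by field.
    if isstructtype(T) && fieldcount(T) > 0 && !ismutable(x)
        fields = ntuple(i -> pseudo_zero(getfield(x, i)), fieldcount(T))
        # Assumes T has its default constructor taking all fields in order.
        return T(fields...)
    end
    # Non-differentiable leaf (Symbol, Int, String, ...): keep the primal value.
    return x
end

struct Foo2{X,Y}
    x::X
    y::Y
end

shadow = pseudo_zero(Foo2(3.0, :two))   # x zeroed, y kept as :two
nt_shadow = pseudo_zero((a = 3.0, b = :two))
```

Copying the primal's Symbol is only one choice; any non-null value would satisfy the isdefined check on the first pointer field.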