EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Reverse mode apply iterate #1485

Closed wsmoses closed 3 weeks ago

wsmoses commented 1 month ago
wmoses@beast:~/git/Enzyme.jl ((HEAD detached at origin/reviterate)) $ cat batch.jl 
using Enzyme

Enzyme.API.printall!(true)

concat() = ()
concat(a) = a
concat(a, b) = (a..., b...)
concat(a, b, c...) = concat(concat(a, b), c...)

function make_byref(out, x)
    res = 0.0
    x = Base.inferencebarrier(@inbounds (x[1][1],x[1][2]))
    for v in x
        v = v::Float64
        res += v*v
    end
    out[] = res
    nothing
end

function make_byref2(out, x)
    res = 0.0
    x = Base.inferencebarrier(@inbounds (x[1][1],x[1][2]))
    tup = iterate(x)
    if tup !== nothing
        res += tup[1]::Float64
    end
    out[] = res
    nothing
end

x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]
dx2 = [(0.0, 0.0), (0.0, 0.0)]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), BatchDuplicated(x, (dx, dx2)))
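For reference, the shadow values this reproducer should produce can be worked out by hand (a sketch; the `expected_*` names are illustrative, not from the PR): `make_byref` computes `res = x[1][1]^2 + x[1][2]^2`, so the derivative with respect to `x[1][i]` is `2 * x[1][i]`, scaled by each batch member's output seed (`dout = 1.0`, `dout2 = 3.0`).

```julia
# Hand-worked expected shadows for the reproducer above (illustrative).
# make_byref only reads x[1], so only dx[1]/dx2[1] should be nonzero.
x1 = (2.0, 3.0)
expected_dx  = (2 * x1[1] * 1.0, 2 * x1[2] * 1.0)  # seed dout  = 1.0
expected_dx2 = (2 * x1[1] * 3.0, 2 * x1[2] * 3.0)  # seed dout2 = 3.0
```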

@gbaraldi

github-actions[bot] commented 1 month ago

Benchmark Results

| Benchmark | main | 18c6fa18c064e6... | main/18c6fa18c064e6... |
|---|---|---|---|
| basics/overhead | 4.03 ± 0.001 ns | 4.34 ± 0.01 ns | 0.929 |
| time_to_load | 0.35 ± 0.0015 s | 0.376 ± 0.0013 s | 0.932 |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR. Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

wsmoses commented 1 month ago

@vchuravy per signal convo, my best memory of where this left off (@gbaraldi please correct me).

We find a segfault. Specifically we are apparently setting the name of a datatype to something new which is wrong (specifically tuple.name or something).

Backtracing to the culprit, we find that it is getting overwritten in the `+=` of the reverse pass. Specifically we have a for loop of apply generics (for the iterate) and also an unstable get-nth-index-of. These result in an allocation of shadow pointers from their results for use in the reverse pass. We should have a cache of `Tuple{Ref{Float64},Ref{Float64}}`. The tuple of size 2 comes from the batch width; the `Ref` is what the shadow from the augmented primal of the unstable get-nth-index-of should return.
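To make the expected tape layout concrete, here is an illustrative sketch (plain Julia, not Enzyme's real data structures): one `Tuple{Ref{Float64},Ref{Float64}}` per primal loop iteration, the 2-tuple from the batch width of 2 and each `Ref` being the shadow the augmented primal returns.

```julia
# Illustrative shape of the cache we expect, not Enzyme internals:
# one (shadow_batch1, shadow_batch2) entry per primal loop iteration.
cache = Tuple{Base.RefValue{Float64}, Base.RefValue{Float64}}[]
for _ in 1:2                       # two primal iterations over the tuple
    push!(cache, (Ref(0.0), Ref(0.0)))
end
```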

In the reverse pass, we should have two iterations, just like the forward pass has two iterations. It is the last iteration of the reverse pass [aka the theoretical index 0] which is the issue. The cache does not actually contain 2 (for the two iterations of the loop) `Tuple{Ref{Float64},Ref{Float64}}`'s. Instead it contains as first element (1=Tuple.name, 2=Tuple.name) and then (1=RefValue{1.0}, 2=RefValue{1.0}) [random float values added, I don't recall what they were]. We had earlier spent a rabbit hole looking at the assembly of the indexing, since LLVM appears to have split the induction variable into two counters: one which is negative and increments up, and another which counts down. One is used for indexing, the other for the reverse-pass loop exit. This was originally suspicious since the indexing one on the last reverse-pass iteration (aka i==0) seemed to index at an offset of -8, possibly implying the loop was doing an incorrect iteration. However, this negation was actually expected, as it was basically LLVM strength reduction moving part of the indexing out of the loop.

So the fundamental issue here is why is the 0th element of the cache busted. So far, no clue.

It is possible this could be simplified a bit from the Julia end by basically turning it into explicit iterate calls and then doing some more precise type-stability work -- but maybe not.
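A hypothetical version of that simplification (not from the PR; the function name is illustrative) would unroll the iteration protocol into explicit `iterate` calls, so the call sites Enzyme differentiates are plain `iterate` invocations rather than the lowered for-loop machinery:

```julia
# Hypothetical sketch: the same reduction written with explicit iterate
# calls and a narrowing assert, so each differentiated call site is a
# plain `iterate`. Illustrative only.
function make_byref_explicit(out, x)
    res = 0.0
    xs = Base.inferencebarrier(@inbounds (x[1][1], x[1][2]))
    st = iterate(xs)
    while st !== nothing
        v, s = st
        v = v::Float64
        res += v * v
        st = iterate(xs, s)
    end
    out[] = res
    nothing
end
```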

For obvious reasons of unknown loop size we use the exponential allocation, which makes things more obnoxious to look at given the generated junk IR. However, it may be a bug in the Julia side of the exponential copy? Just a guess though, as we ended up very, very stuck.
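For readers unfamiliar with the scheme: since the trip count is unknown at compile time, the tape is grown by doubling, with each resize copying the existing entries over (the "exponential copy" suspected above). A minimal standalone sketch of the pattern, in plain Julia rather than the LLVM-level form Enzyme actually emits:

```julia
# Grow-by-doubling cache for a loop of unknown trip count (sketch only).
mutable struct ExpCache{T}
    buf::Vector{T}
    len::Int
end
ExpCache{T}() where {T} = ExpCache{T}(Vector{T}(undef, 1), 0)

function cache_push!(c::ExpCache{T}, v::T) where {T}
    if c.len == length(c.buf)
        newbuf = Vector{T}(undef, 2 * length(c.buf))  # exponential growth
        copyto!(newbuf, 1, c.buf, 1, c.len)           # the "exponential copy"
        c.buf = newbuf
    end
    c.len += 1
    c.buf[c.len] = v
    nothing
end

cache_get(c::ExpCache, i::Int) = c.buf[i]  # reverse pass reads by index
```

If the copy step mishandled an element, the stale slot would surface exactly as a corrupted early index when the reverse pass reads the cache back.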

vchuravy commented 1 month ago

We crash due to data corruption:

$2 = (jl_datatype_t *) 0x7ac98c0950d0
(rr) p jdt->name
$3 = (jl_typename_t *) 0x7ac9a1df6ae0 <jl_system_image_data+72308000>
(rr) p jdt->name->name
$4 = (jl_sym_t *) 0x4030000000000000
(rr) p jdt->name->name

Reverse executing

Thread 1 hit Hardware watchpoint 1: *(jl_sym_t **) 0x7ac9a1df6ae0

Old value = (jl_sym_t *) 0x4030000000000000
New value = (jl_sym_t *) 0x4010000000000000
0x00007ac9b308e9a5 in * () at float.jl:411
warning: 411    float.jl: No such file or directory
(rr) bt
#0  0x00007ac9b308e9a5 in * () at float.jl:411
#1  julia_make_byref_1412 (out=..., x=<optimized out>)
    at /home/vchuravy/src/Enzyme/rev_apply_iterate.jl:14
#2  0x00007ac9b308e9a5 in diffe2julia_make_byref_1412wrap ()
#3  0x00007ac9b30990d8 in macro expansion ()
    at /home/vchuravy/src/Enzyme/src/compiler.jl:5916
#4  enzyme_call () at /home/vchuravy/src/Enzyme/src/compiler.jl:5566

Second overwrite:

Thread 1 hit Hardware watchpoint 1: *(jl_sym_t **) 0x7ac9a1df6ae0

Old value = (jl_sym_t *) 0x4010000000000000
New value = (jl_sym_t *) 0x7ac9ab47c1d8
0x00007ac9b308e997 in * () at float.jl:411
411 in float.jl

(rr) bt 4
#0  0x00007ac9b308e997 in * () at float.jl:411
#1  julia_make_byref_1412 (out=..., x=<optimized out>)
    at /home/vchuravy/src/Enzyme/rev_apply_iterate.jl:14
#2  0x00007ac9b308e997 in diffe2julia_make_byref_1412wrap ()
#3  0x00007ac9b30990d8 in macro expansion ()
    at /home/vchuravy/src/Enzyme/src/compiler.jl:5916
(More stack frames follow...)

Since this memory is allocated in the system image (it's the name of Tuple) it smells more like a calling mistake.

(rr) p jl_(jt)
Tuple{Base.RefValue{DataType}}
$1 = void
vchuravy commented 1 month ago

Instead it contains as first element (1=Tuple.name, 2=Tuple.name)

So how did these values end up in the cache?

vchuravy commented 1 month ago

So I find:

#12 0x00007ac9b444a212 in ijl_apply_generic (
    F=0x7ac98ec33350 <jl_system_image_data+1456912>, args=0x7fffbe9ae1d8,
    nargs=5) at /home/vchuravy/src/julia-1.10/src/gf.c:3077
3077        return _jl_invoke(F, args, nargs, mfunc, world);
(rr) p jl_(F)
Enzyme.Compiler.idx_jl_getfield_rev
$7 = void
(rr) p jl_(args[0])
Base.RefValue{DataType}(x=Tuple{Float64, Int64})
$8 = void
(rr) p jl_(args[1])
(1=Any, 2=Any)
$9 = void
(rr) p jl_(args[2])
Base.Val{1}
$10 = void
(rr) p jl_(args[3])
Base.Val{false}()
$11 = void
(rr) p jl_(args[4])
Base.RefValue{DataType}(x=Tuple{Float64, Int64})
$12 = void

A bit fishy. The code got changed here but args[0] should be a Val

wsmoses commented 1 month ago

@vchuravy I don't think so?

The calling conv of that is `function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst}`

where dptr (and dptrs for extra batches) are the shadows, not Vals?
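As a sanity check on the convention, here is an illustrative stub (not Enzyme's real method body) that maps the rr-observed `args` onto the declared signature: `args[0] -> dptr` (shadow of the container, batch member 1), `args[1] -> dret` (the tape/shadow return), `args[2] -> Val{1}` (the index as a type), `args[3] -> Val{false}` (constness), `args[4] -> dptrs[1]` (batch member 2's shadow).

```julia
# Stub with the same signature as idx_jl_getfield_rev, returning its
# arguments so the mapping can be inspected. Illustrative only.
function idx_getfield_rev_sketch(dptr::T, dret, ::Type{Val{symname}},
                                 ::Val{isconst}, dptrs...) where {T, symname, isconst}
    (dptr, dret, symname, isconst, dptrs)
end
```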

wsmoses commented 1 month ago

    if !is_constant_value(gutils, ops[1])
        inp = invert_pointer(gutils, ops[1], B)
        inp = lookup_value(gutils, inp, B)
        if width == 1
            inps = [inp]
        else
            inps = LLVM.Value[]
            for w in 1:width
                push!(inps, extract_value!(B, inp, w-1))
            end
        end
    else
        inp = new_from_original(gutils, ops[1])
        inp = lookup_value(gutils, inp, B)
        inps = [inp]
    end

    vals = LLVM.Value[]
    push!(vals, inps[1])

    push!(vals, tape)

    sym = new_from_original(gutils, ops[2])
    sym = lookup_value(gutils, sym, B)
    sym = (sizeof(Int) == sizeof(Int64) ? emit_box_int64! : emit_box_int32!)(B, sym)
    sym = emit_apply_type!(B, Base.Val, [sym])
    push!(vals, sym)

    push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig))))

    for v in inps[2:end]
        push!(vals, v)
    end

    pushfirst!(vals, unsafe_to_llvm(idx_jl_getfield_rev))

so the function, original shadow, tape [aka dret], Type{Val{sym}}, batched other shadows [from 2 ... n]

the shadow tape does look fishy though.

wsmoses commented 1 month ago

@vchuravy okay the caching mechanism is fine, the function is actually returning garbage.

"Result of call to idx_jl_getfield_aug"
(1=Tuple.name, 2=Tuple.name)
Tuple.name
Tuple.name
"Val(AnyArray)="
wsmoses commented 1 month ago

The input to idx_jl_getfield_aug is RefValue{DataType}(Tuple{Float64, Int64})

codecov-commenter commented 4 weeks ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 96.35%. Comparing base (ad7694e) to head (9499ee4).


Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main    #1485       +/-   ##
===========================================
+ Coverage   71.15%   96.35%   +25.19%
===========================================
  Files          30        9       -21
  Lines       11224      411    -10813
===========================================
- Hits         7986      396     -7590
+ Misses       3238       15     -3223
```
