EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
461 stars 66 forks source link

Small inference test #2152

Open wsmoses opened 2 days ago

wsmoses commented 2 days ago

Pasted here for ease.

What I'm using locally for some testing @vchuravy

module ReverseRules

# Scratch reproducer for a slow first `autodiff` call: defines a trivial custom
# reverse-mode rule for `q`, then runs one compile-latency diagnostic on it.

using Enzyme
using Enzyme: EnzymeRules
using LinearAlgebra
using Test

# Simple scalar test function (not used by the diagnostics below).
f(x) = x^2

# Squares the first element of `x` in place (not used by the diagnostics below).
function f_ip(x)
    x[1] *= x[1]
    return nothing
end

import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig
using .EnzymeRules

# Function that gets the custom reverse-mode rule defined below.
q(x) = x^2

# Custom forward pass for `q`.
# The tape is a tuple of two `Ref`s so that `reverse` can check that a by-reference
# tape reaches it with its contents intact. The primal value is returned only when
# `needs_primal(config)` asks for it; no shadow is returned.
function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(q)},
                          ::Type{<:Active}, x::Active)
    tape = (Ref(2.0), Ref(3.4))
    if needs_primal(config)
        return AugmentedReturn(func.val(x.val), nothing, tape)
    else
        return AugmentedReturn(nothing, nothing, tape)
    end
end

# Custom reverse pass for `q`: verifies the tape built by `augmented_primal`, then
# returns a one-element tuple holding the gradient contribution for `x`.
# NOTE(review): the `10 +` / `100 +` offsets make the result deliberately differ from
# the true derivative `2x * dret`, presumably so a test can tell which branch ran.
# The commented-out testset at the end expects 104.0 = 100 + 2 * 2.0 * 1.0.
function reverse(config::RevConfigWidth{1}, ::Const{typeof(q)}, dret::Active, tape,
                 x::Active)
    @test tape[1][] == 2.0
    @test tape[2][] == 3.4
    if needs_primal(config)
        return (10 + 2 * x.val * dret.val,)
    else
        return (100 + 2 * x.val * dret.val,)
    end
end

# ---------------------------------------------------------------------------------
# Compile-latency diagnostics.
#
# Both `@profile` and `@snoop_inference` are only informative on the *first* (cold)
# `autodiff` call. Once that call has compiled, a second measurement in the same
# session sees almost no work. So exactly one diagnostic runs per session, selected
# with an environment variable:
#
#     REVERSERULES_DIAGNOSTIC=profile    (default) profile, then serve with PProf
#     REVERSERULES_DIAGNOSTIC=inference  write the inference timing tree to dat.txt
#     REVERSERULES_DIAGNOSTIC=none       only define the rules above
#
# Each measured call sits in its own small loop-free top-level statement, as in a
# plain script. NOTE(review): Julia may compile, rather than interpret, top-level
# blocks that contain loops. That would infer `autodiff` before the measurement starts,
# so do not merge the keep-alive loop into a measuring block.
# ---------------------------------------------------------------------------------

const DIAGNOSTIC = get(ENV, "REVERSERULES_DIAGNOSTIC", "profile")
DIAGNOSTIC in ("profile", "inference", "none") || throw(ArgumentError(
    "REVERSERULES_DIAGNOSTIC must be \"profile\", \"inference\" or \"none\"; " *
    "got $(repr(DIAGNOSTIC))"))

using Profile

if DIAGNOSTIC == "profile"
    @profile Enzyme.autodiff(Enzyme.Reverse, q, Active(2.0))
end

using PProf

if DIAGNOSTIC == "profile"
    pprof(; webhost="<hostname>")  # NOTE(review): placeholder; set a real host
    # Keep the session alive until interrupted (Ctrl-C), presumably so the PProf
    # web UI stays available. Sleeping yields instead of pinning a CPU core.
    while true
        sleep(1)
    end
end

using SnoopCompileCore

if DIAGNOSTIC == "inference"
    tinf = @snoop_inference Enzyme.autodiff(Enzyme.Reverse, q, Active(2.0))
end

# Loaded only after snooping so that loading them is not part of the measurement.
using SnoopCompile
using AbstractTrees

if DIAGNOSTIC == "inference"
    # `open(...) do` closes the file even if `print_tree` throws.
    open("dat.txt", "w") do io
        print_tree(io, tinf; maxdepth=100)
    end
    # Invalidation trees are not written here. `invalidation_trees` expects the log
    # returned by `@snoop_invalidations`, which must wrap the package loads / method
    # definitions that cause invalidations. It does not accept the `tinf` returned
    # by `@snoop_inference`.
end

#@testset "Byref Tape" begin
#    @test Enzyme.autodiff(Enzyme.Reverse, q, Active(2.0))[1][1] ≈ 104.0
#end

end # ReverseRules
vchuravy commented 2 days ago

I think SnoopCompile needs Julia 1.11 to work here, since only then should it see the work done by Enzyme.

vchuravy commented 1 day ago

We should also add this to the benchmarks.

wsmoses commented 1 day ago

Yeah, fair. To be clear, this is just one of our internal tests that I noticed was taking about 30 seconds to run, which is clearly unnecessary here.

wsmoses commented 1 day ago

What's the right way to add it? Presumably all we care about is the first run (i.e., compilation time), but BenchmarkTools excludes that.