EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Check if function is being called inside `autodiff` #1761

Open: avik-pal opened this issue 3 weeks ago

avik-pal commented 3 weeks ago

Conversation from Slack

@avik-pal

In Enzyme, can we check whether a function is being called within an `autodiff` call? Something equivalent to the following ChainRules version:

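using ChainRulesCore, Static
const CRC, ∂∅ = ChainRulesCore, NoTangent()  # aliases used in the snippet
within_gradient(_) = False()
CRC.rrule(::typeof(within_gradient), x) = True(), _ -> (∂∅, ∂∅)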

I tried defining:

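using EnzymeCore: EnzymeCore, EnzymeRules

# Debug probe: throws if Enzyme ever invokes this forward rule
function EnzymeRules.forward(
        ::EnzymeCore.Const{typeof(within_gradient)}, ::Type{RT}, x) where {RT}
    error("within_gradient")
end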

@wsmoses

I think this is getting optimized out, lol. But yeah, we can add something: the C++ interface already has it, so we should bring Julia to parity :P (At the moment Julia will still constant-propagate through functions that have Enzyme rules, so the call can be folded away before the rule is ever used.)

avik-pal commented 2 weeks ago

@wsmoses, is this as simple as adding a `ccall` somewhere? I can make a PR if you point me to the function.

wsmoses commented 2 weeks ago

We have this magic function: https://github.com/EnzymeAD/Enzyme/blob/7f614f43808e5bd3960f3712ac880b38eedc01d6/enzyme/test/Integration/ReverseMode/mycos.c#L39

It has the semantics that `__enzyme_iter(x, y)` = x if not differentiated, x + y at first order, and so on (this is used to keep the right order of a Taylor series).
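A minimal pure-Julia model of the semantics just described may help. It is not Enzyme code: `enzyme_iter_model` and its `order` argument are hypothetical, with `order` standing in for how many times the surrounding code has been differentiated, and only the two cases stated above (no AD and first order) are modeled. The last definition shows why a check like `__enzyme_iter(0, 1) != 0`, as tried in the next comment, could act as an "inside autodiff" test.

```julia
# Hypothetical model of the `__enzyme_iter(x, y)` semantics described above.
# `order` is an assumed stand-in for the differentiation order; Enzyme itself
# has no such argument, it substitutes a definition during the AD transform.
function enzyme_iter_model(x, y, order::Integer)
    order == 0 && return x      # not differentiated: the first argument
    order == 1 && return x + y  # first-order AD: x + y
    error("orders above 1 are not modeled in this sketch")
end

# An "are we inside autodiff?" test built on it: false at order 0, true at order 1.
within_autodiff_model(order::Integer) = enzyme_iter_model(0, 1, order) != 0
```

For example, `within_autodiff_model(0)` is `false` and `within_autodiff_model(1)` is `true`.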

It probably needs a bit of Julia integration, but maybe something like this?

avik-pal commented 2 weeks ago

Is this not an "actual" function present in the binary? I tried:

function within_autodiff()
    return ccall((:__enzyme_iter, libEnzyme), UInt64, (UInt64, UInt64), 0, 1) != 0
end

but that symbol is not present.

wsmoses commented 2 weeks ago

No, it's a fictitious function that Enzyme replaces with a definition (more precisely, during the AD transform, as described above).

As with the Enzyme pass, we may need to tell Enzyme to replace any uses of the function that remain after AD with its first argument.

(Reply sent by email to Avik Pal's comment of Sat, Aug 31, 2024, 5:16 PM.)

wsmoses commented 2 weeks ago

Hm, okay, I realize this won't work quite the same way in Julia, since we don't have control over the optimization pipeline for non-differentiated code.

I think the solution is to tell the abstract interpreter to block constant propagation for a special function, which we then rewrite when it appears inside an Enzyme-differentiated context. cc @vchuravy @aviatesk

Some early work on overriding the abstract interpreter is here: https://github.com/EnzymeAD/Enzyme.jl/pull/1443, but I think it needs someone to push it forward at the moment.

wsmoses commented 3 days ago

@avik-pal, so this should basically be as simple as defining a new function that returns false and then doing the same as https://github.com/EnzymeAD/Enzyme.jl/pull/1839 to change it to true (since this interpreter always runs in an autodiff context).
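To make the plan above concrete, here is a sketch under loudly stated assumptions: `within_autodiff_sketch` and `rewrite_for_ad` are hypothetical names, and the rewrite operates on Julia `Expr` trees purely as a stand-in for what Enzyme's abstract interpreter would do during inference. The marker returns `false` in ordinary code; in the "differentiated" version, every call to it is replaced by the constant `true`. The real mechanism must also stop Julia from constant-folding the marker to `false` before the rewrite happens, which this surface-syntax stand-in does not model.

```julia
# Hypothetical marker: ordinary Julia code always sees `false`.
within_autodiff_sketch() = false

# Hypothetical stand-in for the interpreter-level rewrite: every zero-argument
# call to the marker is replaced by the constant `true`. Enzyme would do this
# during abstract interpretation, not on surface syntax as done here.
function rewrite_for_ad(ex)
    ex isa Expr || return ex
    if ex.head === :call && ex.args == Any[:within_autodiff_sketch]
        return true
    end
    return Expr(ex.head, map(rewrite_for_ad, ex.args)...)
end

body = :(within_autodiff_sketch() ? 2x : x)
f_plain = eval(:(x -> $body))                    # ordinary execution
f_ad    = eval(:(x -> $(rewrite_for_ad(body))))  # "differentiated context"
```

Here `f_plain(3)` returns `3`, while `f_ad(3)` returns `6`, because the rewritten body always takes the "inside autodiff" branch.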