MilesCranmer / DispatchDoctor.jl

The dispatch doctor prescribes type stability
Apache License 2.0

Mark the stability checks as non-differentiable? #32

Closed: avik-pal closed this issue 3 months ago

avik-pal commented 3 months ago

See https://buildkite.com/julialang/luxcore-dot-jl/builds/88#018ff4e6-b8ea-4fee-9c5b-8759bfdb35ab/316-4214

Introducing @stable in https://github.com/LuxDL/LuxCore.jl/pull/28 makes the code non-differentiable.

We should probably mark the checking code block as non-differentiable using ChainRulesCore (CRC). Not sure what happens with Enzyme.

MilesCranmer commented 3 months ago

With Enzyme it seems to work — I have an integration test for it: https://github.com/MilesCranmer/DispatchDoctor.jl/blob/02b46f6060c84c632dedf5e602dd3ca6a2c72d95/test/enzyme.jl#L17

What's the right way to hide the instability check from other AD backends? And why doesn't it work out of the box? It's not returning anything different, so I am surprised.

avik-pal commented 3 months ago

Something like ChainRulesCore.@ignore_derivatives().
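A minimal sketch of that suggestion, assuming ChainRulesCore and Zygote are installed. `ChainRulesCore.@ignore_derivatives` wraps the expression in a closure passed to `ChainRulesCore.ignore_derivatives`, which ChainRules-aware backends treat as non-differentiable, so Zygote should not need to compile the closure that holds the logging try/catch. The function `g` is hypothetical and only for illustration.

```julia
using ChainRulesCore, Zygote

# Hypothetical function: logs its input, then doubles it.
function g(x)
    # The `@warn` (and the try/catch it expands to) runs inside a closure
    # that AD is told to ignore, so Zygote does not differentiate through it.
    ChainRulesCore.@ignore_derivatives @warn "$(x)"
    return 2x
end

Zygote.gradient(g, 0.5)  # expected: (2.0,), with the warning still printed
```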

Upon more testing, the problem seems to arise only when the default mode is "warn". A minimal example:

```julia
using Zygote, DispatchDoctor

@stable default_mode="warn" foo(x) = sum(x)

Zygote.gradient(foo, rand(10))
```

The problem comes from the try/catch in the generated code:

```julia
begin
    #= logging.jl:371 =#
    try
        #= logging.jl:372 =#
        var"#99#msg" = (DispatchDoctor._Errors.TypeInstabilityWarning)("`f`", "REPL[42]:1", (), (;), () .=> (), var"##f_return_type#258")
        #= logging.jl:373 =#
        var"#100#kwargs" = (;)
        #= logging.jl:374 =#
        true
    catch var"#113#err"
        #= logging.jl:376 =#
        Base.invokelatest(Base.CoreLogging.logging_error, var"#95#logger", var"#91#level", var"#94#_module", var"#93#group", var"#96#id", var"#97#file", var"#98#line", var"#113#err", true)
        #= logging.jl:377 =#
        false
    end
end && Base.CoreLogging.invokelatest(Base.CoreLogging.handle_message, var"#95#logger", var"#91#level", var"#99#msg", var"#94#_module", var"#93#group", var"#96#id", var"#97#file", var"#98#line"; var"#100#kwargs"...)
```
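To check that this try/catch comes from the logging macro itself rather than from DispatchDoctor, one can expand `@warn` directly with `@macroexpand` and search the resulting expression tree. This is a Base-only sketch; `contains_try` is a hypothetical helper written for this illustration. It uses an interpolated message, like the MWE later in the thread, because on some Julia versions a plain constant string message may expand without a try/catch.

```julia
# Hypothetical helper: recursively search an expression tree for a `try` block.
contains_try(ex) = false
contains_try(ex::Expr) = ex.head === :try || any(contains_try, ex.args)

# Expand (but do not run) a warning with an interpolated message.
# `x` does not need to be defined because nothing is evaluated here.
expanded = @macroexpand @warn "$(x)"

println(contains_try(expanded))   # expected: true
println(contains_try(:(x + 1)))   # plain arithmetic has no try block
```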
MilesCranmer commented 3 months ago

Hm, interesting. I wonder if this should be patched here or within CRC. It seems like you would always want AD to skip @warn.

avik-pal commented 3 months ago

Oh wait, I see it is purely generated by Logging.

But I am not entirely sure how to patch this in CRC, because the generated code contains the try ... catch inline rather than calling into a function that has one.

cc @oxinabox if you have any thoughts on whether this can be done in CRC
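Because the try/catch is inlined into the caller, one possible workaround (a sketch, not necessarily what DispatchDoctor ended up doing) is to move the warning into its own function and mark that function with `ChainRulesCore.@non_differentiable`, so Zygote uses the rule instead of compiling the function body. This assumes ChainRulesCore and Zygote are installed; `_warn_value` and `f` are hypothetical names.

```julia
using ChainRulesCore, Zygote

# Hypothetical helper: the try/catch generated by `@warn` now lives inside
# this function's body instead of being inlined into the caller.
function _warn_value(x)
    @warn "$(x)"
    return nothing
end

# Tell ChainRules-aware AD backends (e.g. Zygote) not to look inside the helper.
@non_differentiable _warn_value(::Any)

function f(x)
    _warn_value(x)
    return x
end

Zygote.gradient(f, 0.5)  # expected: (1.0,), with the warning still printed
```

The design choice here is that a named function gives CRC something to attach a rule to, which is exactly what the inlined macro expansion lacks.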

MilesCranmer commented 3 months ago

Here's an MWE without DispatchDoctor, just the warning:

```julia
julia> using Zygote

julia> function f(x)
           @warn "$(x)"
           x
       end
f (generic function with 1 method)

julia> Zygote.gradient(f, 0.5)
ERROR: Compiling Tuple{typeof(f), Float64}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
```

Seems like a bug in Zygote.jl? (Well, not a bug; just something that needs a workaround.)

MilesCranmer commented 3 months ago

https://github.com/FluxML/Zygote.jl/issues/269 🙂

MilesCranmer commented 3 months ago

Okay, seeing that it hasn't been fixed in nearly 5 years, I'll just patch it here. Will make a PR soon.