dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

`Umlaut.getcode` making invalid call #113

Closed cscherrer closed 1 year ago

cscherrer commented 2 years ago

Hi, I think I'm seeing a bug in Umlaut when called from Yota. Here's a MWE:

using TransformVariables, Umlaut, Yota

tr = as((a=asℝ, b = asℝ))
function f(x)
    nt = transform(tr, x)
    nt.a + nt.b
end

# This works
val, tape = trace(f, zeros(2))

# This too
val, tape = trace(f, zeros(2); ctx=Yota.GradCtx)

# Throws `ERROR: Code for this Method is not available.`
grad(f, zeros(2))

To get some idea where it's breaking, ...

julia> Umlaut.print_stack_trace()
[1] _transform_tuple(flag::TransformVariables.LogJacFlag, x::AbstractVector, index, ts) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:163
[2] _transform_tuple(flag::TransformVariables.LogJacFlag, x::AbstractVector, index, ts) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:163
[3] transform_tuple(flag::TransformVariables.LogJacFlag, tt::Tuple{Vararg{TransformVariables.AbstractTransform, N}} where N, x, index) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:175
[4] transform_with(flag::TransformVariables.LogJacFlag, tt::TransformVariables.TransformTuple{<:NamedTuple}, x, index) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:227
[5] transform(t::TransformVariables.VectorTransform, x::AbstractVector) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/generic.jl:265
[6] f(x) in Main at REPL[5]:1

To get some more detail, I looked at the original stack trace, which includes a call to Umlaut.getcode. Adding a line @show f, types to that function, I see that it's trying to call code_lowered(tuple, (Float64,)), which leads to the crash.

dfdx commented 2 years ago

Looks like tuple is a somewhat special function:

julia> foo(args...) = tuple(args...)
foo (generic function with 1 method)

julia> tuple(1.2)
(1.2,)

julia> foo(1.2)
(1.2,)

julia> code_lowered(tuple, (Float64,))
Core.CodeInfo[]

julia> code_lowered(foo, (Float64,))
1-element Vector{Core.CodeInfo}:
 CodeInfo(
1 ─ %1 = Core._apply_iterate(Base.iterate, Main.tuple, args)
└──      return %1
)

So in general the error makes sense to me. However, it's unclear why the tracer tried to recurse into tuple since it is marked as a primitive in most contexts:

julia> Umlaut.isprimitive(Umlaut.BaseCtx(), tuple, 1.2)
true

I'll try to debug it in the next couple of days.
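
As a minimal sketch of why reflection comes back empty here: `tuple` is a builtin implemented in the runtime rather than a generic function with Julia methods, so there is no lowered code to return for it. The `can_lower` helper below is a hypothetical name (it is not part of Umlaut) and only illustrates the kind of guard a tracer could apply before asking for code.

```julia
# `tuple` is a builtin: it lives in the runtime, not in Julia method definitions.
@show tuple isa Core.Builtin        # true
@show Base.tuple === Core.tuple     # true: both names refer to the same builtin

# An ordinary generic function that merely wraps the builtin.
foo(args...) = tuple(args...)
@show foo isa Core.Builtin          # false

# Hypothetical guard (an assumption, not Umlaut's API):
# only request lowered code when the callee is not a builtin.
can_lower(f) = !(f isa Core.Builtin)

println("tuple: can_lower = ", can_lower(tuple))
println("foo:   can_lower = ", can_lower(foo))
```

With a guard like this, a call to `tuple` would have to be treated as a primitive (or handled by a rule) instead of being traced into.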

dfdx commented 2 years ago

Could you please post the package versions? With the following specs:

(Yota) pkg> st
     Project Yota v0.7.3
      Status `~/work/Yota/Project.toml`
  [082447d4] ChainRules v1.28.1
  [d360d2e6] ChainRulesCore v1.14.0
  [26cc04aa] FiniteDifferences v0.12.24
  [872c559c] NNlib v0.8.4
  [92992a2b] Umlaut v0.2.4
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random
  [10745b16] Statistics
  [8dfed614] Test
  [cf7118a7] UUIDs

I see a different error:

ERROR: ArgumentError: argument is not a generic function
Stacktrace:
  [1] which(f::Any, t::Any)
    @ Base ./reflection.jl:1320
  [2] trace(f::Function, args::Float64; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/ESm8s/src/trace.jl:341
  [3] #gradtape#90
    @ ~/work/Yota/src/grad.jl:243 [inlined]
  [4] make_rrule(f::Function, args::Float64)
    @ Yota ~/work/Yota/src/cr_api.jl:128
  [5] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Float64)
    @ Yota ~/work/Yota/src/cr_api.jl:170
  [6] rrule(::Yota.YotaRuleConfig, ::typeof(Core._apply_iterate), ::typeof(iterate), ::typeof(tuple), ::Tuple{Float64}, ::Tuple{})
    @ Yota ~/work/Yota/src/rulesets.jl:28

cscherrer commented 2 years ago

  [84d833dd] TransformVariables v0.6.2
  [92992a2b] Umlaut v0.2.4
  [cd998857] Yota v0.7.3

julia> versioninfo()
Julia Version 1.8.0-rc1
Commit 6368fdc656 (2022-05-27 18:33 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: 32 × AMD Ryzen Threadripper 2950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver1)
  Threads: 1 on 32 virtual cores

dfdx commented 2 years ago

Technical note, MRE (julia-1.8.0-rc1):

using TransformVariables, Umlaut, Yota

trace(TransformVariables._transform_tuple, TransformVariables.NoLogJac(), zeros(2), 2, (asℝ,); ctx=Yota.GradCtx())

dfdx commented 2 years ago

I understand the problem, but my naive fix in #114 didn't work as expected. As a workaround for now, you can define this rule:

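using ChainRulesCore  # brings rrule and NoTangent into scope

function ChainRulesCore.rrule(::typeof(tuple), args...)
    y = tuple(args...)
    # tuple is not differentiable itself; pass each element of dy back to its argument
    return y, dy -> (NoTangent(), collect(dy)...)
end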

Note that you still will need a few more rrules (at least first(::NamedTuple)).


In essence, this code touches a pretty sensitive piece of the codebase that dynamically generates new rrules. This code works fine with rules defined in ChainRules.jl (or just without a specific config), but the rule for tuple(args...) only exists in Yota and is not processed correctly by rrule_via_ad(). I tried to fix the invocation, but it led to an endless recursion between rrule() and rrule_via_ad(), and it isn't obvious how to make them flexible and friendly to each other. So it's kind of a design question that I need to sleep on.

Even more minimal reproducible examples:

trace(x -> Core._apply_iterate(Base.iterate, Core.tuple, x), 2.0; ctx=Yota.GradCtx())
Yota.rrule_via_ad(Yota.YotaRuleConfig(), Core.tuple, 2.0)

cscherrer commented 2 years ago

Thanks @dfdx, this is very helpful!

Note that you still will need a few more rrules

I'm not sure how specific this is to Yota (maybe it's a more general ChainRules question?) but do you know of a good strategy for knowing which rrules should be defined, first to get things working and then to improve performance?

dfdx commented 2 years ago

Usually, when an rrule is missing, you get an error similar to this:

ERROR: No deriative rule found for op %228 = first(%226)::NamedTuple{(:a, :b), Tuple{Float64, Float64}}, try defining it using 

        ChainRulesCore.rrule(::typeof(first), ::Tuple{NamedTuple{(:a, :b), Tuple{Float64, Float64}}, TransformVariables.NoLogJac, Int64}) = ...

So the general rule is just to try it out and see what's missing.
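
One way to probe for a missing rule before running `grad`, as a rough sketch: ChainRulesCore's generic fallback `rrule` returns `nothing` when no rule applies, so a small helper can test for that. `has_rrule` and `double` are hypothetical names introduced only for illustration, and the check ignores rule configs such as `Yota.YotaRuleConfig`, so it only approximates what Yota itself sees.

```julia
using ChainRulesCore

# Hypothetical helper: true when some `rrule` method applies to this call.
has_rrule(f, args...) = rrule(f, args...) !== nothing

# Toy function (an assumption for this example) that has no rule yet.
double(x) = 2x
before = has_rrule(double, 3.0)     # hits the generic fallback

# Define the missing rule: the primal value plus a pullback.
function ChainRulesCore.rrule(::typeof(double), x)
    double_pullback(dy) = (NoTangent(), 2 * dy)
    return double(x), double_pullback
end

after = has_rrule(double, 3.0)      # now finds the rule defined above
y, pullback = rrule(double, 3.0)
println((before, after, y, pullback(1.0)))
```

The same pattern applies to the call named in the error message: define the rule it suggests, then rerun to see what is still missing.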

In your case, though, Yota hit a double issue: not only was the rrule missing, but AD also went through the valley of the shadow of recursive rrule_via_ad, which led to a complex combination of otherwise harmless corner cases.

(I'm a bit overwhelmed these days, but I remember this issue isn't fully resolved yet)

cscherrer commented 2 years ago

Thanks @dfdx, I think I see. I'll leave this open so we remember the tuple case isn't fully addressed yet. But it looks like your suggestion should work to avoid the issue.

dfdx commented 1 year ago

rrule() for tuple is now in the library and tested with several use cases, so I'm closing this issue. Note that the following still doesn't work:

Using tuple in any other context should work fine now, e.g.:

julia> grad(x -> sum(tuple(x, x)), 2.0)
(4.0, (ChainRulesCore.ZeroTangent(), 2.0))

Feel free to reopen if you encounter any more issues.