FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Strip zygote frames from mutation error stack trace #1501

Open · LilithHafner opened this issue 5 months ago

LilithHafner commented 5 months ago

Motivation and description

When differentiating something complicated which contains mutation, it can be hard to know exactly where the mutation is. In this example, the mutation is tucked away inside the ComponentArray constructor, and in a larger example (e.g. https://github.com/DARPA-ASKEM/sciml-service/issues/141) it might be hard to figure that out.

It would be very helpful if the stack trace showed the exact location of the mutation that triggers this error, rather than interleaving those frames with Zygote frames. Failing that, it would at least be nice to tell the user to look at every third frame to find where in their code the mutation is.

julia> using ComponentArrays, Zygote

julia> function f(x)
           ca = ComponentArray(var=x)
           ca.var
       end
f (generic function with 1 method)

julia> Zygote.jacobian(f, [1,2,3])
ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:70
  [3] (::Zygote.var"#547#548"{Vector{Any}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:89
  [4] (::Zygote.var"#2643#back#549"{Zygote.var"#547#548"{Vector{Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] merge
    @ ./namedtuple.jl:371 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(merge), @NamedTuple{}, Base.Generator{…}}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [7] make_idx
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:170 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [9] make_carray_args
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:151 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [11] make_carray_args
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:144 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Vector{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [13] ComponentArray
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:64 [inlined]
 [14] #ComponentArray#21
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:67 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ComponentVector{Int64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [16] ComponentArray
    @ ~/.julia/dev/ComponentArrays/src/componentarray.jl:67 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ComponentVector{Int64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [18] f
    @ ./REPL[2]:2 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [20] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [21] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [22] call_composed
    @ ./operators.jl:1045 [inlined]
 [23] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{…}, Tuple{…}, @Kwargs{}}, Any})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [24] call_composed
    @ ./operators.jl:1044 [inlined]
 [25] #_#103
    @ ./operators.jl:1041 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [27] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [28] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [29] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [30] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [32] withjacobian(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:150
 [33] jacobian(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/grad.jl:128
 [34] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.
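Until Zygote does something like this itself, a caught stack trace can be post-processed by hand. The sketch below is a minimal illustration, assuming a simple heuristic: any frame whose file lives under a `packages/Zygote` or `packages/ZygoteRules` directory counts as AD machinery. `frame`, `is_zygote_frame` and `strip_zygote_frames` are hypothetical helpers, not part of Zygote's API, and the sample frames are simplified stand-ins for the trace above.

```julia
using Base.StackTraces: StackFrame

# Hypothetical helper: build a StackFrame by hand (no MethodInstance, not from C).
frame(func, file, line; inlined = false) =
    StackFrame(Symbol(func), Symbol(file), line, nothing, false, inlined, 0)

# Heuristic (an assumption, not a Zygote API): a frame is AD machinery if its
# file lives under a Zygote or ZygoteRules package directory.
const ZYGOTE_PATH = r"[/\\]packages[/\\]Zygote(Rules)?[/\\]"

is_zygote_frame(sf::StackFrame) = occursin(ZYGOTE_PATH, string(sf.file))

# Keep only the frames that point at Base, user, or library code.
strip_zygote_frames(frames::AbstractVector{StackFrame}) =
    filter(!is_zygote_frame, frames)

# Simplified stand-ins for the first few frames of the trace above.
frames = [
    frame(:error, "./error.jl", 35),
    frame(:_throw_mutation_error, "/home/user/.julia/packages/Zygote/jxHJc/src/lib/array.jl", 70),
    frame(Symbol("#back#549"), "/home/user/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl", 72),
    frame(:merge, "./namedtuple.jl", 371; inlined = true),
    frame(:Pullback, "/home/user/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl", 0),
    frame(:make_idx, "/home/user/.julia/dev/ComponentArrays/src/componentarray.jl", 170; inlined = true),
]

for sf in strip_zygote_frames(frames)
    println(sf.func, " @ ", sf.file, ":", sf.line)
end
```

In a real session the frames would come from `stacktrace(catch_backtrace())` inside a `catch` block wrapped around the `Zygote.jacobian(f, [1,2,3])` call. Filtering by path is crude: it keeps every inlined source frame and drops everything Zygote owns, which is roughly the "look at every third frame" advice done automatically.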

ToucheSir commented 5 months ago

The story with Zygote stacktraces is more complex than described and could use a little explaining. If we use this stacktrace for illustration:

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:70

The first two stackframes are what you'd expect: common error-reporting code. More interesting are the next two:

  [3] (::Zygote.var"#547#548"{Vector{Any}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:89
  [4] (::Zygote.var"#2643#back#549"{Zygote.var"#547#548"{Vector{Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72

As you've correctly identified, [3] has the actual rule we should look at. So what is [4]? That would be the rule machinery itself at https://github.com/FluxML/ZygoteRules.jl/blob/f9bf0e367fa259c5aa68f0e14ccbf2125d734bd6/src/adjoint.jl#L72. Not very helpful.

Now for the surprising revelation: there is actually no "interleaving of Zygote frames in this stacktrace". From [2] to [33], it's all Zygote. But how can that be when we have frames like this?

  [5] merge
    @ ./namedtuple.jl:371 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(merge), @NamedTuple{}, Base.Generator{…}}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0

Essentially, Zygote's generated functions can spoof line numbers from the original function, so the merge frame at [5] is actually the same call as [6] (the AD-generated pullback).
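As the trace shows, Julia prints the frames inlined at a call site (marked `[inlined]`) before the real frame that hosts them, so the pairing of [5] with [6] can be recovered mechanically. The sketch below is a minimal illustration, assuming the trace is already available as a `Vector{StackFrame}`; `frame`, `group_inlined` and `condensed_line` are hypothetical helpers, and the frames are simplified stand-ins for [5] to [8] above.

```julia
using Base.StackTraces: StackFrame

# Hypothetical helper: build a StackFrame by hand (no MethodInstance, not from C).
frame(func, file, line; inlined = false) =
    StackFrame(Symbol(func), Symbol(file), line, nothing, false, inlined, 0)

# Pair each real (non-inlined) host frame with the inlined frames listed before it.
function group_inlined(frames::AbstractVector{StackFrame})
    groups = Tuple{Vector{StackFrame},StackFrame}[]
    pending = StackFrame[]
    for sf in frames
        if sf.inlined
            push!(pending, sf)
        else
            push!(groups, (pending, sf))
            pending = StackFrame[]
        end
    end
    # Any trailing inlined frames have no host and are ignored in this sketch.
    return groups
end

# One line per call: the primal function(s) and location(s) carried by the
# inlined frames, tagged with the name of the host frame that actually ran.
function condensed_line(group)
    inlined, host = group
    isempty(inlined) && return string(host.func, " @ ", host.file, ":", host.line)
    calls = join((string(sf.func, " @ ", sf.file, ":", sf.line) for sf in inlined), " <- ")
    return string("[", host.func, "] ", calls)
end

# Simplified stand-ins for frames [5]-[8]; the host names are placeholders.
frames = [
    frame(:merge, "./namedtuple.jl", 371; inlined = true),
    frame(:Pullback, "interface2.jl", 0),
    frame(:make_idx, "componentarray.jl", 170; inlined = true),
    frame(:Pullback, "interface2.jl", 0),
]

for group in group_inlined(frames)
    println(condensed_line(group))
end
```

On a real trace the frames would come from `stacktrace(catch_backtrace())`. The condensed layout is only one possible way to print one line per call; it is not something Zygote currently provides.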

I don't know why this was done. Either it was intentional, to help with looking up the original function since the generated function's code provides little info to work with, or the line info is sticking around by accident. This snippet also shows the limitations Zygote has around stackframe printing. Ideally, we'd want the call info of [5] with the file name and line number of [6], so that stacktraces would be shorter and cleaner. I'm assuming this is possible, but the main problem is that Zygote's internals are a PITA to work with (mostly because of IRTools, IMO).