dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

Repeated calls to replace! have quadratic complexity #142

Open mschauer opened 6 months ago

mschauer commented 6 months ago

https://github.com/dfdx/Yota.jl/blob/54dc7127569feb4d29960ef8f241cef5a739b6a5/src/grad.jl#L152

If I understand correctly, this is because the remainder of the tape is rewritten every time. Maybe there is a potential efficiency gain here. Ignore this if it has already been considered.

dfdx commented 6 months ago

Yes, it's a known issue and a notable historical artifact. Initially, Yota was designed to work with its own custom rules, which only required adding new operations to the tape and never replacing existing ones. However, with ChainRules we have to replace a single operation:

%2 = foo(%1)

with 3 operations:

%2 = rrule(foo, %1)
%3 = getfield(%2, 1)
%4 = getfield(%2, 2)

Since variable IDs change, we need to rebind!() all affected operations (plus the context). Each replacement therefore touches the rest of the tape, which by itself leads to quadratic complexity.
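To make the cost concrete, here is a minimal sketch on a toy tape. The `ToyOp` type and the `expand_rrule!` and `total_rebinds` functions are hypothetical illustrations, not Yota's or Umlaut's API. Each call is expanded into the rrule/getfield triple shown above, and every later operation is rebound; for a chain of n operations the total number of rebinds is (n - 1)(n - 2)/2, i.e. quadratic in the tape length.

```julia
# Toy linear tape: the op at position j defines variable %j, and `args`
# holds the IDs of earlier variables it reads. Hypothetical, for illustration.
struct ToyOp
    fn::Symbol
    args::Vector{Int}
end

# Replace the call at position `i` with the ChainRules-style triple
#   %i = rrule(fn, args...); %(i+1) = getfield(%i, 1); %(i+2) = getfield(%i, 2)
# and rebind every later op. Returns the number of ops that were rebound.
function expand_rrule!(tape::Vector{ToyOp}, i::Int)
    op = tape[i]
    splice!(tape, i, [ToyOp(Symbol(:rrule_, op.fn), op.args),
                      ToyOp(:getfield1, [i]),    # primal value
                      ToyOp(:getfield2, [i])])   # pullback
    # old %i (the primal value) is now %(i+1); any old %k with k > i is now %(k+2)
    remap(a) = a == i ? i + 1 : (a > i ? a + 2 : a)
    rebound = 0
    for j in (i + 3):length(tape)
        tape[j] = ToyOp(tape[j].fn, map(remap, tape[j].args))
        rebound += 1
    end
    return rebound
end

# Chain %1 = input, %k = f(%(k-1)); expand every call, one replacement at a time.
function total_rebinds(n::Int)
    tape = [ToyOp(:input, Int[]); [ToyOp(:f, [k - 1]) for k in 2:n]]
    total, i = 0, 2
    while i <= length(tape)
        total += expand_rrule!(tape, i)
        i += 3   # skip the triple just inserted; the next original call follows it
    end
    return total, tape
end

println(total_rebinds(100)[1])   # 98 * 99 / 2 rebinds for 99 expanded calls
```

Doubling n roughly quadruples the count, which matches the behavior described in the issue.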

To avoid this, we would need to either drop ChainRules or at least redesign the tape to support multiple output variables. Unfortunately, neither option is realistic at this stage of development.

mschauer commented 6 months ago

Thanks. For now, I have my own

%2 = my_rule(foo, %1)
%3 = getfield(%2, 1)
%4 = getfield(%2, 2)

in the Umlaut.record_primitive! code. Perhaps it can still be done in a single pass without redesigning the tape, by avoiding the replace! function and keeping a dictionary of ID changes; I'll think about it. By the way, I wrote a message on Slack. Thank you for Umlaut.jl!
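The dictionary idea can be sketched on a toy tape. This is a minimal sketch, assuming a hypothetical `Stmt` type and `expand_single_pass` function (not Umlaut's API), and it ignores corner cases. It rebuilds the tape as a separate pass, whereas doing it inside `record_primitive!` would apply the same mapping while tracing. Each statement is visited once and each argument is looked up in a `Dict`, so the work is linear in the tape length instead of quadratic.

```julia
# Toy tape statement: `args` are IDs of earlier variables. Hypothetical type.
struct Stmt
    fn::Symbol
    args::Vector{Int}
end

# Build the expanded tape in one pass. `newid` maps each old variable ID to
# the ID that holds the same value on the new tape, so no statement is revisited.
function expand_single_pass(old::Vector{Stmt})
    new = Stmt[]
    newid = Dict{Int,Int}()
    for (oldid, st) in enumerate(old)
        args = [newid[a] for a in st.args]       # rebind via the map, O(1) per arg
        if st.fn == :input
            push!(new, Stmt(:input, args))
            newid[oldid] = length(new)
        else
            push!(new, Stmt(Symbol(:rrule_, st.fn), args))  # %r = rrule(fn, args...)
            r = length(new)
            push!(new, Stmt(:getfield1, [r]))    # primal value
            newid[oldid] = length(new)           # later uses of the old %id read this
            push!(new, Stmt(:getfield2, [r]))    # pullback
        end
    end
    return new, newid
end

old = [Stmt(:input, Int[]), Stmt(:f, [1]), Stmt(:g, [2])]
new, newid = expand_single_pass(old)
println(length(new))   # 7: the input plus two rrule/getfield triples
```

The design trade-off is that every consumer of the tape must go through `newid` while the new tape is being built, which is exactly the kind of bookkeeping that becomes fragile once corner cases enter the picture.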

dfdx commented 6 months ago

True! I experimented with single-pass AD for some time, but eventually decided to keep the two passes separate because of corner cases. For example, Julia has a rather special way of representing vararg (splatted) calls:

julia> foo(xs) = print(xs...)
foo (generic function with 2 methods)

julia> Umlaut.getcode(foo, (Vector{Float64},))
1 1 ─ %1 = Core._apply_iterate(Base.iterate, Main.print, _2)::Core.Const(nothing) 
  │
  └──      return %1  
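As an illustration of what a tracer encounters here, the splatted call and the lowered `Core._apply_iterate` form can be compared directly. `Core._apply_iterate` is an internal builtin of recent Julia versions, so this is only a sketch of the lowering, not a supported public API; `+` is used instead of `print` so that there is a value to compare.

```julia
# Julia lowers `f(xs...)` to Core._apply_iterate(Base.iterate, f, xs),
# which is the call a tracer sees instead of a direct call to f.
xs = [1.0, 2.0, 3.0]
splatted = +(xs...)
lowered = Core._apply_iterate(Base.iterate, +, xs)
println(splatted == lowered)   # prints true; both compute 1.0 + 2.0 + 3.0
```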

Usually, you don't want to handle things like Core._apply_iterate() yourself, check inputs for consistency, etc., so Umlaut does most of the hard work automatically:

julia> trace(foo, [1, 2, 3.0])
1.02.03.0(nothing, Tape{Umlaut.BaseCtx}
  inp %1::typeof(foo)
  inp %2::Vector{Float64}
  %3 = check_variable_length(%2, 3, 2)::Nothing 
  %4 = __to_tuple__(%2)::Tuple{Float64, Float64, Float64} 
  %5 = getfield(%4, 1)::Float64 
  %6 = getfield(%4, 2)::Float64 
  %7 = getfield(%4, 3)::Float64 
  %8 = print(%5, %6, %7)::Nothing 
)

But if you override too many of the "internal" functions like record_primitive!(), you risk breaking this behavior. After a number of such issues, I decided to keep the simpler two-stage design, with separate forward and reverse passes.

On the other hand, you don't need to cover all corner cases, so you have a good chance of getting it done!