JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

Record `Broadcast.broadcasted` instead of `Broadcast.broadcast` #215

Open torfjelde opened 1 year ago

torfjelde commented 1 year ago

IIUC, ReverseDiff records `broadcast` and, whenever possible, uses ForwardDiff to specialize further on broadcasted statements. This gives much better performance than tracing through every operation with `ReverseDiff.TrackedReal`.

Unfortunately, this means that when one uses `Broadcast.broadcasted`, i.e. lazy broadcasting, the broadcast is not recorded. We then end up on the less desirable path of tracing through the broadcast element by element with `ReverseDiff.TrackedReal`:

```julia
julia> using ReverseDiff

julia> f(x) = sum(exp.(x))
f (generic function with 1 method)

julia> f_tape = ReverseDiff.GradientTape(f, (rand(10, ),))
typename(ReverseDiff.GradientTape)(f)

julia> g(x) = sum(Broadcast.instantiate(Broadcast.broadcasted(exp, x)))
g (generic function with 1 method)

julia> g_tape = ReverseDiff.GradientTape(g, (rand(10, ),))
typename(ReverseDiff.GradientTape)(g)

julia> length(g_tape.tape)
19

julia> g_tape.tape
19-element Vector{ReverseDiff.AbstractInstruction}:
 ScalarInstruction(exp):
  input:  TrackedReal<76u>(0.5057895559423533, 0.0, 6vX, 1, 5d7)
  output: TrackedReal<AW2>(1.6582943198420965, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.6582943198420965)
 ScalarInstruction(exp):
  input:  TrackedReal<KUi>(0.13213345349262395, 0.0, 6vX, 2, 5d7)
  output: TrackedReal<JmS>(1.141260614319831, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.141260614319831)
 ScalarInstruction(+):
  input:  (TrackedReal<AW2>(1.6582943198420965, 0.0, 6vX, ---),
           TrackedReal<JmS>(1.141260614319831, 0.0, 6vX, ---))
  output: TrackedReal<Cng>(2.7995549341619275, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<GdN>(0.034478177830953305, 0.0, 6vX, 3, 5d7)
  output: TrackedReal<DNm>(1.03507944045113, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.03507944045113)
 ScalarInstruction(+):
  input:  (TrackedReal<Cng>(2.7995549341619275, 0.0, 6vX, ---),
           TrackedReal<DNm>(1.03507944045113, 0.0, 6vX, ---))
  output: TrackedReal<DEh>(3.8346343746130573, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<CEJ>(0.04867133616730335, 0.0, 6vX, 4, 5d7)
  output: TrackedReal<GVI>(1.0498752380105207, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.0498752380105207)
 ScalarInstruction(+):
  input:  (TrackedReal<DEh>(3.8346343746130573, 0.0, 6vX, ---),
           TrackedReal<GVI>(1.0498752380105207, 0.0, 6vX, ---))
  output: TrackedReal<7HD>(4.884509612623578, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<Ft8>(0.8637862831888328, 0.0, 6vX, 5, 5d7)
  output: TrackedReal<5lH>(2.3721252497772487, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(2.3721252497772487)
 ScalarInstruction(+):
  input:  (TrackedReal<7HD>(4.884509612623578, 0.0, 6vX, ---),
           TrackedReal<5lH>(2.3721252497772487, 0.0, 6vX, ---))
  output: TrackedReal<HpK>(7.256634862400826, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<Gmo>(0.0039196786165185404, 0.0, 6vX, 6, 5d7)
  output: TrackedReal<1R0>(1.0039273706035023, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.0039273706035023)
 ScalarInstruction(+):
  input:  (TrackedReal<HpK>(7.256634862400826, 0.0, 6vX, ---),
           TrackedReal<1R0>(1.0039273706035023, 0.0, 6vX, ---))
  output: TrackedReal<8hX>(8.260562233004329, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<L8R>(0.9153223594101434, 0.0, 6vX, 7, 5d7)
  output: TrackedReal<4qL>(2.4975802406432295, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(2.4975802406432295)
 ScalarInstruction(+):
  input:  (TrackedReal<8hX>(8.260562233004329, 0.0, 6vX, ---),
           TrackedReal<4qL>(2.4975802406432295, 0.0, 6vX, ---))
  output: TrackedReal<17T>(10.758142473647558, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<JYS>(0.15063946146751517, 0.0, 6vX, 8, 5d7)
  output: TrackedReal<6gL>(1.1625774285521715, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.1625774285521715)
 ScalarInstruction(+):
  input:  (TrackedReal<17T>(10.758142473647558, 0.0, 6vX, ---),
           TrackedReal<6gL>(1.1625774285521715, 0.0, 6vX, ---))
  output: TrackedReal<Dvu>(11.92071990219973, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<B28>(0.3010502862135006, 0.0, 6vX, 9, 5d7)
  output: TrackedReal<5Ku>(1.3512772904478805, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.3512772904478805)
 ScalarInstruction(+):
  input:  (TrackedReal<Dvu>(11.92071990219973, 0.0, 6vX, ---),
           TrackedReal<5Ku>(1.3512772904478805, 0.0, 6vX, ---))
  output: TrackedReal<2ZC>(13.27199719264761, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<Cpk>(0.02748173107946794, 0.0, 6vX, 10, 5d7)
  output: TrackedReal<9uF>(1.0278628369912384, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.0278628369912384)
 ScalarInstruction(+):
  input:  (TrackedReal<2ZC>(13.27199719264761, 0.0, 6vX, ---),
           TrackedReal<9uF>(1.0278628369912384, 0.0, 6vX, ---))
  output: TrackedReal<J8N>(14.299860029638849, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
```

```julia
julia> using BenchmarkTools

julia> @benchmark ReverseDiff.gradient($f, $(randn(1000)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.308 μs …  1.218 ms  ┊ GC (min … max): 0.00% … 95.95%
 Time  (median):     10.105 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.275 μs ± 36.408 μs  ┊ GC (mean ± σ):  9.10% ±  3.05%

  ▇██▇▆▅▅▄▃▁                                       ▁▁▁▁▁▁     ▂
  ██████████▇▅▄▅▄▄▃▃▄▅▅▅▄▅▄▄▁▃▄▁▄▃▃▁▁▁▁▁▁▃▄▃▁▃▄▅▅▆████████▇▇▇ █
  9.31 μs      Histogram: log(frequency) by time      27.8 μs <

 Memory estimate: 55.83 KiB, allocs estimate: 15.

julia> @benchmark ReverseDiff.gradient($g, $(randn(1000)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  151.875 μs …   3.504 ms  ┊ GC (min … max):  0.00% … 94.23%
 Time  (median):     158.858 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   183.434 μs ± 253.436 μs  ┊ GC (mean ± σ):  12.42% ±  8.45%

  ▄▆▆█▇▇▅▄▃▂▁▁▁                                                 ▂
  ██████████████▇▆▇▆▆▇▆▆▅▅▄▅▆▃▄▄▃▁▄▄▃▁▃▄▁▁▁▁▁▁▄▃▃▅▄▄▅▄▄▃▅▅▃▆▆▆▆ █
  152 μs        Histogram: log(frequency) by time        259 μs <

 Memory estimate: 374.53 KiB, allocs estimate: 8009.
```

The overhead can of course be lowered if the tape is compiled:

```julia
julia> x = randn(1000);

julia> inputs = (x,); results = (similar(x),); cfg = ReverseDiff.GradientConfig(inputs);

julia> g_tape = ReverseDiff.GradientTape(g, inputs);

julia> compiled_g_tape = ReverseDiff.compile(g_tape)
typename(ReverseDiff.CompiledTape)(g)

julia> @benchmark ReverseDiff.gradient!($results, $compiled_g_tape, $inputs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  51.798 μs … 117.899 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     54.679 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   55.396 μs ±   4.165 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▃▂▅▁▁▇█▂ ▃▃       ▁              ▂    ▁▁                   ▂
  ▇▇████████▄███▆▇▇▆▆▇█▆▄▄▇▇▄▃▃▆▇▃▁▁██▆▁▁▃██▄▁▄▃▆▅▁▅▃▁▅▅▅▅▃▃▄▅ █
  51.8 μs       Histogram: log(frequency) by time      73.1 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
```

but it is still roughly 5× slower than the uncompiled, ForwardDiff-backed gradient of `f` above (median ~55 μs vs. ~10 μs). This of course varies with input size etc., but I'm guessing the performance difference is well established, given that basically all reverse-mode AD frameworks in Julia use ForwardDiff for broadcasting.
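To illustrate why forward mode fits broadcasting so well: for an elementwise `y = f.(x)` the Jacobian is diagonal, so a single dual-number evaluation per element yields every partial derivative the reverse pass needs. The sketch below is a simplified, Base-only illustration of that idea. `MiniDual` and `broadcast_with_pullback` are hypothetical helpers written for this example; this is not ReverseDiff's implementation, which uses `ForwardDiff.Dual` and is far more general.

```julia
# Hypothetical minimal dual number (a stand-in for ForwardDiff.Dual).
struct MiniDual
    val::Float64   # primal value
    der::Float64   # derivative carried alongside it
end

# Chain rule for exp: d/dx exp(u) = exp(u) * u'
Base.exp(d::MiniDual) = MiniDual(exp(d.val), exp(d.val) * d.der)

# Hypothetical helper: forward pass of y = f.(x) plus a pullback closure.
function broadcast_with_pullback(f, x::AbstractVector{Float64})
    duals = f.(MiniDual.(x, 1.0))     # seed each element with derivative 1.0
    y = [d.val for d in duals]        # primal outputs
    dydx = [d.der for d in duals]     # diagonal of the Jacobian
    pullback(ybar) = ybar .* dydx     # vector-Jacobian product for a diagonal Jacobian
    return y, pullback
end

x = [0.0, 1.0, 2.0]
y, back = broadcast_with_pullback(exp, x)
grad = back(ones(length(x)))          # gradient of sum(exp.(x)) w.r.t. x
@assert y ≈ exp.(x)
@assert grad ≈ exp.(x)
```

The whole broadcast becomes one tape entry with one cached derivative array, instead of one `TrackedReal` instruction per element as in the `g_tape` dump above.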

Would it be possible to record `Broadcast.broadcasted` instead of `Broadcast.broadcast` in ReverseDiff.jl? This is how it's done in Zygote.jl.
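For context on why hooking `broadcasted` would cover both cases: in Julia 1.x the dot syntax `exp.(x)` lowers to `Broadcast.materialize(Broadcast.broadcasted(exp, x))`, and `broadcast(f, x)` is itself defined as `materialize(broadcasted(f, x))`. Lazy code such as `g` above calls `broadcasted` without ever materializing. The Base-only sketch below shows the difference and involves no AD at all.

```julia
# Base-only illustration of eager vs. lazy broadcasting (no AD involved).
x = [0.0, 1.0, 2.0]

# Eager: the dot syntax lowers to materialize(broadcasted(exp, x)),
# which allocates the result array immediately.
eager = exp.(x)

# Lazy: `broadcasted` only builds a `Broadcasted` object describing the call.
bc = Broadcast.broadcasted(exp, x)
@assert bc isa Broadcast.Broadcasted

# `instantiate` fills in the axes so the lazy object can be reduced directly,
# here by `sum`, without first materializing exp.(x) into an array.
lazy_sum = sum(Broadcast.instantiate(bc))

# `materialize` is the step the dot syntax performs to produce an array.
@assert Broadcast.materialize(bc) == eager
@assert lazy_sum ≈ sum(eager)
```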