CDCgov / Rt-without-renewal

https://cdcgov.github.io/Rt-without-renewal/
Apache License 2.0

Test `Tapir` AD for EpiAware models #454

Open SamuelBrand1 opened 1 month ago

SamuelBrand1 commented 1 month ago

This could be working now: https://github.com/TuringLang/Turing.jl/pull/2289/files#review-changes-modal

Tapir AD looks like a real advance on ReverseDiff, so testing it would be worthwhile: https://github.com/compintell/Tapir.jl

seabbs commented 1 month ago

We should look at adding it to our benchmarking.

SamuelBrand1 commented 1 month ago

One thing to look for is whether differentiating accumulate calls gets a boost. An rrule exists for accumulate here, but it's unclear to me whether ReverseDiff picks it up.
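For context on what an accumulate rule buys: for y = accumulate(+, x), each x[i] feeds every y[j] with j >= i, so the pullback is just a reversed cumulative sum of the output cotangent, an O(n) pass with no tape. Below is a minimal plain-Julia sketch of that pullback; it is not the ChainRules rrule itself, and the name cumsum_with_pullback is hypothetical. Whether ReverseDiff or Tapir actually use a rule like this, rather than tracing element by element, is the open question above.

```julia
# Minimal sketch (plain Julia, no AD package): the pullback of y = accumulate(+, x).
# Since y[j] = x[1] + ... + x[j], input x[i] feeds every output y[j] with j >= i,
# so the input cotangent is xbar[i] = ybar[i] + ybar[i+1] + ... + ybar[end].
# `cumsum_with_pullback` is a hypothetical helper, not a library function.
function cumsum_with_pullback(x::AbstractVector)
    y = accumulate(+, x)
    # Reverse cumulative sum of the output cotangent: a single O(n) pass.
    pullback(ybar) = reverse(accumulate(+, reverse(ybar)))
    return y, pullback
end

# For f(x) = sum(identity, accumulate(+, x)), the output cotangent is a
# vector of ones, so the gradient is [n, n-1, ..., 1].
x = randn(5)
y, pb = cumsum_with_pullback(x)
xbar = pb(ones(length(x)))
println(xbar)  # [5.0, 4.0, 3.0, 2.0, 1.0]
```

A hand-written rule like this is why accumulate-heavy code can benefit when the AD backend dispatches to a rule instead of recording every scalar addition.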

yebai commented 1 month ago

cc @willtebbutt who can help.

willtebbutt commented 1 month ago

Thanks for tagging me in this, @yebai. To know for sure whether Tapir.jl will be of use, I'd need to know a bit more about exactly what you want to differentiate, but here's a quick demo involving accumulate:

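using Pkg
Pkg.activate(; temp=true)  # throwaway environment for the demo
Pkg.add(["BenchmarkTools", "ReverseDiff", "Tapir"])
using BenchmarkTools, ReverseDiff, Tapir

# Toy objective: sum of the cumulative sum of x.
f(x) = sum(identity, accumulate(+, x))
x = randn(1_000_000);

# Cost of the primal (forward) evaluation, for reference.
@benchmark f($x)

# ReverseDiff: record and compile the tape once, then reuse it for every gradient.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x));
gradient_storage = zero(x);
@benchmark ReverseDiff.gradient!($gradient_storage, $tape, $x)

# Tapir: build the reverse-mode rule once, then reuse it for every gradient.
rule = Tapir.build_rrule(f, x)
@benchmark Tapir.value_and_gradient!!($rule, f, $x)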

yields

julia> @benchmark f($x)
BenchmarkTools.Trial: 1320 samples with 1 evaluation.
 Range (min … max):  1.747 ms … 280.507 ms  ┊ GC (min … max):  0.00% … 99.13%
 Time  (median):     2.330 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   3.780 ms ±   9.231 ms  ┊ GC (mean ± σ):  34.48% ± 19.17%

  ▄█      ▁                                                    
  ███▁▁▁▃██▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▁▁▃▆▃▄▅▄▆▃▄▄▆▅▅▃▁▅▅ ▇
  1.75 ms      Histogram: log(frequency) by time      27.5 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 2.

julia> @benchmark ReverseDiff.gradient!($gradient_storage, $tape, $x)
BenchmarkTools.Trial: 38 samples with 1 evaluation.
 Range (min … max):  127.823 ms … 140.864 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     133.395 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   133.754 ms ±   3.520 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁   ▁ █  ██▁▁▁▁    █  ██ ▁█  ▁▁▁  ▁  ▁▁ ▁█ ▁  ▁   ▁█▁ ▁     ▁  
  █▁▁▁█▁█▁▁██████▁▁▁▁█▁▁██▁██▁▁███▁▁█▁▁██▁██▁█▁▁█▁▁▁███▁█▁▁▁▁▁█ ▁
  128 ms           Histogram: frequency by time          141 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark Tapir.value_and_gradient!!($rule, f, $x)
BenchmarkTools.Trial: 106 samples with 1 evaluation.
 Range (min … max):  37.834 ms … 589.776 ms  ┊ GC (min … max):  0.00% … 91.95%
 Time  (median):     39.679 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   48.074 ms ±  56.167 ms  ┊ GC (mean ± σ):  16.70% ± 14.45%

  █▆                                                            
  ███▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▆ ▄
  37.8 ms       Histogram: log(frequency) by time       155 ms <

 Memory estimate: 22.91 MiB, allocs estimate: 272.

(Annoyingly, you won't be able to run this example for a couple of hours: I messed something up in the way Tapir.jl interacts with BenchmarkTools.jl, and the fix should be available on the General registry within the next couple of hours. I had to dev Tapir.jl and check out the appropriate branch to run this.)

There's a decent speed-up over ReverseDiff.jl in this case (a median of about 40 ms for Tapir.jl versus about 133 ms for ReverseDiff.jl, roughly 3.4x). I'd be interested to know if you've got any other examples you're keen to try out!

SamuelBrand1 commented 1 month ago

Hey @willtebbutt ,

Thanks for coming over to show this!

The broad outline of our interest here is that we expose constructors for various ways of defining discrete-time epidemiological models. Any time-stepping is (generally) done with a scanning function that uses Base.accumulate under the hood to propagate a state forward in time, dependent on some other process (think of the time-varying reproduction number).

When doing inference, anything that speeds up the gradient calls here is going to be very useful.
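A minimal, hypothetical sketch of the kind of scan described above: a discrete renewal model, I[t] = R[t] * sum(g[s] * I[t-s] for s in 1:K), stepped forward with a single Base.accumulate call. The names renewal_step, g, Rt and init, and all numbers, are illustrative assumptions; this is not EpiAware's constructor API.

```julia
# Hypothetical sketch, not EpiAware's API: a discrete renewal model
#   I[t] = R[t] * sum(g[s] * I[t-s] for s in 1:K)
# stepped forward in time with Base.accumulate.

# Generation-interval weights for lags 1..K, most recent lag first (illustrative values).
g = (0.2, 0.5, 0.3)

# Build the scan step; `state` holds the last K infection counts, most recent first.
function renewal_step(g::NTuple{K,Float64}) where {K}
    return function (state, R)
        I_new = R * sum(g .* state)            # renewal equation for one time step
        return (I_new, Base.front(state)...)   # push I_new, drop the oldest count
    end
end

Rt = [1.5, 1.5, 1.2, 1.0, 0.8, 0.8]   # time-varying reproduction number
init = (10.0, 10.0, 10.0)             # seed infections for the first K lags

# Each element of `states` is the K-step window after that step;
# its first entry is that step's new incidence.
states = accumulate(renewal_step(g), Rt; init = init)
incidence = first.(states)
```

Because each step reads only the previous window, the whole trajectory is one accumulate call, so the gradient cost of inference hinges on how well the AD backend handles that call. Holding the window in an NTuple is a design choice of this sketch that avoids allocating a new array at every step.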

willtebbutt commented 1 month ago

Sounds good.

I'm keen to help out, so please do ping me if I can be of use.

SamuelBrand1 commented 1 month ago

So long as you have a moderate-to-high tolerance for stupid questions I'll take you up on that!

wsmoses commented 1 month ago

Also could be fun to try Enzyme.jl at the same time.

In Turing code it generally sees an extra order of magnitude over Tapir and also is increasingly getting adopted by big Julia packages as the new default AD.

yebai commented 1 month ago

> it generally sees an extra order of magnitude over Tapir

In my experience, the performance difference between Tapir and Enzyme seems relatively small for Turing models with non-trivial computation. @willtebbutt did an excellent job capitalising on the recent improvements in Julia's compiler API.

wsmoses commented 1 month ago

sounds like another reason to run more benchmarks then :)

SamuelBrand1 commented 1 month ago

> Also could be fun to try Enzyme.jl at the same time.
>
> In Turing code it generally sees an extra order of magnitude over Tapir and also is increasingly getting adopted by big Julia packages as the new default AD.

Somewhere on my hard drive I've got a first-pass script for a simple renewal epi model aimed at working with Enzyme (based on the code in the Box model), but my day-to-day has been a bit intense.