ptiede / Comrade.jl

MIT License

Move AD to Enzyme #285

Closed by ptiede 2 months ago

ptiede commented 1 year ago

Enzyme is fast, works on GPUs, and allows mutation. This branch will try to move Comrade to use Enzyme exclusively for all its AD needs going forward. This will involve changing a number of things:

This is not going to be quick, but I have done some preliminary testing on geometric models and it looks like Enzyme should work. For example, for the posterior in black_hole_image.jl we get the following results:

Current Comrade 0.7.1:

ForwardDiff

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  232.671 μs …   2.651 ms  ┊ GC (min … max):  0.00% … 84.44%
 Time  (median):     243.484 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   283.548 μs ± 252.489 μs  ┊ GC (mean ± σ):  12.69% ± 12.27%

  █▃                                                            ▁
  ███▆▃▄▁▃▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆▇█▇ █
  233 μs        Histogram: log(frequency) by time       1.99 ms <
 Memory estimate: 1.19 MiB, allocs estimate: 1301.

Zygote (so slow :-1:)

BenchmarkTools.Trial: 1874 samples with 1 evaluation.
 Range (min … max):  1.832 ms … 11.902 ms  ┊ GC (min … max):  0.00% … 68.71%
 Time  (median):     2.056 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.666 ms ±  1.966 ms  ┊ GC (mean ± σ):  20.68% ± 20.09%

  █▇▇▅▃▂                                                      
  ███████▆▄▄▄▄▁▁▄▄▄▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▅▇▅▇█▅▇█████ █
  1.83 ms      Histogram: log(frequency) by time     9.69 ms <

 Memory estimate: 4.66 MiB, allocs estimate: 31776.

Enzyme Reverse

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  123.492 μs … 386.016 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     135.454 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   136.159 μs ±   7.468 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                         ▃▃      ▃█▆  ▂▁   ▃▂ ▁▁                ▁
  ▆▅▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇▄▃▃▁▃▄██▆▅▆▅▅▅████████▇▇██▇██▇█▇▆▅▅▅▆▅▄▅▁▅▃▃ █
  123 μs        Histogram: log(frequency) by time        146 μs <

 Memory estimate: 32 bytes, allocs estimate: 1.

Enzyme Forward

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  83.218 μs … 247.591 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     91.853 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   94.859 μs ±  10.496 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

         █▃▇▃▁▁▁                                               ▁
  ▆█▆▇▂█▇████████▇▇▇▆▆▆▆▄▅▅▅▅▅▃▅▅▄▄▄▄▃▃▄▃▃▅▃▂▅▄▄▄▅▅▅▆▆▆▆▅▆▅▆▆▆ █
  83.2 μs       Histogram: log(frequency) by time       154 μs <

 Memory estimate: 48 bytes, allocs estimate: 2.

Note that this is only a 10-dimensional model, so we would expect Forward-mode AD systems to win here.
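The expectation above comes from how the two modes scale: forward mode needs roughly one pass per input direction, while reverse mode needs a single backward sweep but carries more bookkeeping per sweep. A minimal sketch of the forward-mode side, using a hand-rolled dual number in plain Julia: the `Dual` type, `toy_logdensity`, and `forward_gradient` are hypothetical names for illustration only, not Comrade, Enzyme, or ForwardDiff APIs, and real packages batch several directions per pass.

```julia
# Illustrative dual number: a value plus its derivative along one seeded direction.
struct Dual
    val::Float64   # primal value
    der::Float64   # directional derivative
end

# Only the arithmetic needed by the toy function below.
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, a::Dual) = Dual(c * a.val, c * a.der)

# Hypothetical stand-in for a log-posterior: standard normal log-density up to a constant.
function toy_logdensity(x)
    acc = x[1] * x[1]
    for i in 2:length(x)
        acc = acc + x[i] * x[i]
    end
    return -0.5 * acc
end

# Full gradient via forward mode: one evaluation of f per input dimension.
function forward_gradient(f, x::Vector{Float64})
    n = length(x)
    grad = zeros(n)
    passes = 0
    for i in 1:n
        # Seed direction i: derivative 1 for x[i], 0 for every other input.
        duals = [Dual(x[j], j == i ? 1.0 : 0.0) for j in 1:n]
        grad[i] = f(duals).der
        passes += 1
    end
    return grad, passes
end

x = collect(1.0:10.0)   # a 10-dimensional input, matching the model size above
grad, passes = forward_gradient(toy_logdensity, x)
println("passes = ", passes)
println("grad   = ", grad)
```

With only 10 inputs the number of forward passes stays small, which is why forward mode can beat reverse mode here; for models with many more parameters the per-dimension cost grows and reverse mode should become the better choice.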

Once this pull request is finished, we should see some pretty nice speed-ups. Additionally, the code will be non-allocating, which should let us thread more efficiently in a later pull request.

ptiede commented 2 months ago

This is dead thanks to the new interface.