LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Native Nested AD support for Lux Models #598

Closed avik-pal closed 4 months ago

avik-pal commented 4 months ago

I will integrate it with DifferentiationInterface later on.

This must go in before #591, because #591 breaks almost all Zygote-over-Zygote support.

Simple Benchmarks

using Lux, Zygote, Random, BenchmarkTools

model = Chain(Dense(2, 4, gelu), Dense(4, 4, gelu), Dense(4, 2))

ps, st = Lux.setup(Xoshiro(), model)
x = randn(Float32, 2, 3)

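function loss_function(model, x, ps, st)
    # Wrap the model so it can be called as smodel(x) with ps and st captured.
    smodel = StatefulLuxLayer(model, ps, st)
    # Squared norm of the input gradient; the outer Zygote.gradient call differentiates through this.
    return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)))
end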

@benchmark Zygote.gradient($loss_function, $model, $x, $ps, $st)
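
For readers new to the pattern: the loss above is a gradient penalty, so the benchmark differentiates a function that itself contains a gradient. Below is a minimal, dependency-free sketch of the same structure on a toy linear map, where both derivatives can be written by hand and checked against finite differences. The names `inner_grad`, `penalty`, `penalty_grad`, and `finite_diff` are illustrative assumptions and are not part of Lux or Zygote; the closed forms hold only for this toy case.

```julia
using LinearAlgebra, Random

# Toy stand-in for the model: a single linear map, so f(x) = sum(abs2, W * x).
inner_grad(W, x) = 2 .* (W' * (W * x))        # closed form of the gradient of f w.r.t. x
penalty(W, x) = sum(abs2, inner_grad(W, x))   # same structure as `loss_function` above

# Hand-derived gradient of `penalty` w.r.t. W (valid for this toy case only).
function penalty_grad(W, x)
    g = inner_grad(W, x)
    return 4 .* ((W * x) * g' .+ (W * g) * x')
end

# Central finite differences as an independent check of the hand derivation.
function finite_diff(f, W; h=1e-6)
    G = similar(W)
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += h
        Wm = copy(W); Wm[i] -= h
        G[i] = (f(Wp) - f(Wm)) / (2h)
    end
    return G
end

rng = Xoshiro(0)
W = randn(rng, 4, 2)
x = randn(rng, 2)
@assert isapprox(penalty_grad(W, x), finite_diff(W -> penalty(W, x), W); rtol=1e-5)
```

In the benchmark, automatic differentiation has to produce the equivalent of `penalty_grad` for the full `Chain`, which is why nested AD support matters here.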

New Timings

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   87.777 μs …   4.749 ms  ┊ GC (min … max): 0.00% … 92.06%
 Time  (median):      95.504 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   105.854 μs ± 136.946 μs  ┊ GC (mean ± σ):  4.16% ±  3.18%

    ██▅▂                                                         
  ▂▅████▆▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  87.8 μs          Histogram: frequency by time          191 μs <

 Memory estimate: 148.48 KiB, allocs estimate: 576.

Old Timings

BenchmarkTools.Trial: 9 samples with 1 evaluation.
 Range (min … max):  538.090 ms … 698.453 ms  ┊ GC (min … max): 1.50% … 2.56%
 Time  (median):     629.568 ms               ┊ GC (median):    1.61%
 Time  (mean ± σ):   627.763 ms ±  55.790 ms  ┊ GC (mean ± σ):  1.91% ± 0.46%

  █      █               █      █   █      █         █     █  █  
  █▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁█▁▁▁█▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█▁▁▁▁▁█▁▁█ ▁
  538 ms           Histogram: frequency by time          698 ms <

 Memory estimate: 101.37 MiB, allocs estimate: 2643156.
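
To put the two reports side by side, here is a short sketch that computes the ratios from the numbers printed above. The variable names are illustrative, and the old run has only 9 samples, so the ratios should be read as rough orders of magnitude rather than precise figures.

```julia
# Medians and estimates copied from the two BenchmarkTools reports above.
new_median_s = 95.504e-6       # 95.504 μs
old_median_s = 629.568e-3      # 629.568 ms
new_mem_kib  = 148.48          # 148.48 KiB
old_mem_kib  = 101.37 * 1024   # 101.37 MiB expressed in KiB
new_allocs   = 576
old_allocs   = 2_643_156

speedup     = old_median_s / new_median_s   # roughly 6.6e3
mem_ratio   = old_mem_kib / new_mem_kib     # roughly 7.0e2
alloc_ratio = old_allocs / new_allocs       # roughly 4.6e3

println("median speedup:   ", round(speedup; sigdigits=2))
println("memory reduction: ", round(mem_ratio; sigdigits=2))
println("alloc reduction:  ", round(alloc_ratio; sigdigits=2))
```

On these numbers the median time drops by a factor of about 6,600, the memory estimate by about 700, and the allocation count by about 4,600.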

codecov[bot] commented 4 months ago

Codecov Report

Attention: Patch coverage is 0%, with 134 lines in your changes missing coverage. Please review.

Project coverage is 80.82%. Comparing base (2899de4) to head (0ab2af7). Report is 1 commit behind head on main.

:exclamation: Current head 0ab2af7 differs from pull request most recent head 65797aa. Consider uploading reports for the commit 65797aa to get more accurate results

| Files | Patch % | Lines |
|---|---|---|
| ext/LuxForwardDiffExt.jl | 0.00% | 78 Missing :warning: |
| ext/LuxZygoteExt.jl | 0.00% | 51 Missing :warning: |
| src/utils.jl | 0.00% | 5 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #598      +/-   ##
==========================================
- Coverage   87.60%   80.82%   -6.79%
==========================================
  Files          40       41       +1
  Lines        2082     2216     +134
==========================================
- Hits         1824     1791      -33
- Misses        258      425     +167
```
