LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License
456 stars 54 forks

Improvements to Nested AD #612

Closed by avik-pal 2 months ago

avik-pal commented 2 months ago

@prbzrg can you try this branch? Nested AD should be much faster now.

You will need to rewrite your code to use StatefulLuxLayer, as in the following example:

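using ComponentArrays, Lux, Random, Zygote, ForwardDiff

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)  # store the parameters in a single flat, named array

# Inner gradient with respect to the parameters `z`, reduced to a scalar.
function fn1(nn, z, r, st)
    smodel = StatefulLuxLayer(nn, z, st)
    # Base.Fix1(smodel, r) is `p -> smodel(r, p)`; Base.Fix1(sum, abs2) is `y -> sum(abs2, y)`
    return sum(first(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, r), z)))
end

fn1(nn, ps, r, st)
Zygote.gradient(fn1, nn, ps, r, st)  # outer gradient taken through the inner gradient

# Same pattern, with an inner Jacobian instead of an inner gradient.
function fn2(nn, z, r, st)
    smodel = StatefulLuxLayer(nn, z, st)
    return sum(first(Zygote.jacobian(Base.Fix1(smodel, r), z)))
end

fn2(nn, ps, r, st)
Zygote.gradient(fn2, nn, ps, r, st)  # outer gradient taken through the inner Jacobian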
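
For readers unfamiliar with the pattern in fn1 and fn2, here is a rough, dependency-free sketch of what a stateful wrapper does and how Base.Fix1 turns it into a function of the parameters only. ToyStatefulLayer and toy_dense are hypothetical names made up for this illustration; they are not Lux's StatefulLuxLayer or Dense and do not reflect how Lux implements them.

```julia
# Hypothetical stand-in for illustration only; NOT Lux's StatefulLuxLayer.
# It bundles a model, its parameters, and its state into one callable object.
mutable struct ToyStatefulLayer{M,P,S}
    model::M
    ps::P
    st::S
end

# Calling the wrapper returns only the output; the updated state is stored
# inside the wrapper instead of being returned as a second value.
function (s::ToyStatefulLayer)(x, ps=s.ps)
    y, st_new = s.model(x, ps, s.st)
    s.st = st_new
    return y
end

# A toy layer that follows the (x, ps, st) -> (y, st) calling convention.
toy_dense(x, ps, st) = (tanh.(ps.weight * x .+ ps.bias), st)

ps = (weight = [1.0 2.0; 3.0 4.0], bias = [0.1, -0.2])
layer = ToyStatefulLayer(toy_dense, ps, NamedTuple())

x = [0.5, -1.0]
y = layer(x)  # a single return value, no (output, state) tuple

# Fixing the input gives a function of the parameters alone,
# composed here with a scalar reduction as in fn1.
loss = Base.Fix1(sum, abs2) ∘ Base.Fix1(layer, x)
l = loss(ps)  # same value as sum(abs2, layer(x, ps))
```

Because the wrapped call returns a plain array rather than an (output, state) tuple, it can be handed directly to functions such as Zygote.gradient or Zygote.jacobian, which is the shape the example above relies on.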

codecov[bot] commented 2 months ago

Codecov Report

Attention: Patch coverage is 85.10638%, with 14 lines in your changes missing coverage. Please review.

Project coverage is 87.86%. Comparing base (18b6efe) to head (aaa3b73). Report is 4 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| src/helpers/nested_ad.jl | 78.68% | 13 Missing :warning: |
| ext/LuxForwardDiffExt.jl | 96.00% | 1 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #612      +/-   ##
==========================================
+ Coverage   86.73%   87.86%   +1.13%
==========================================
  Files          41       42       +1
  Lines        2216     2176      -40
==========================================
- Hits         1922     1912      -10
+ Misses        294      264      -30
```


prbzrg commented 2 months ago

I tried it. It works and yes, it's faster.

avik-pal commented 2 months ago

Tests pass locally :tada:. Once they pass on CI, I will merge and tag a release.