avik-pal closed this 2 months ago
@prbzrg, can you try this branch? Nested AD should be much faster now.
You need to rewrite the code to use `StatefulLuxLayer`:
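```julia
# Loading ForwardDiff also loads Lux's ForwardDiff extension (ext/LuxForwardDiffExt.jl)
using ComponentArrays, Lux, Random, Zygote, ForwardDiff

nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)  # 4 input samples with 2 features each
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)  # flatten the parameters into a single differentiable array

# A gradient inside the loss: differentiating fn1 requires nested AD
function fn1(nn, z, r, st)
    # Bundle model, parameters and state so the layer can be called like a function
    smodel = StatefulLuxLayer(nn, z, st)
    # Inner gradient of sum(abs2, smodel(r, z)) with respect to the parameters z
    return sum(first(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, r), z)))
end

fn1(nn, ps, r, st)
Zygote.gradient(fn1, nn, ps, r, st)

# A Jacobian inside the loss: same pattern with Zygote.jacobian
function fn2(nn, z, r, st)
    smodel = StatefulLuxLayer(nn, z, st)
    return sum(first(Zygote.jacobian(Base.Fix1(smodel, r), z)))
end

fn2(nn, ps, r, st)
Zygote.gradient(fn2, nn, ps, r, st)
```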
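The inner loss in `fn1` is built from `Base.Fix1` and `∘`, which can be hard to read at a glance. The sketch below unpacks that idiom using plain Julia only; `toy_model` is a hypothetical stand-in for `smodel(r, z)` and is not Lux code.

```julia
# Hypothetical stand-in for smodel(r, z): input r, parameters z (not Lux code)
toy_model(r, z) = tanh.(z .* r)

r = [0.5, -1.0, 2.0]
z = [1.0, 2.0, 3.0]

# Base.Fix1(f, x) fixes the first argument: Base.Fix1(f, x)(y) == f(x, y)
apply_model = Base.Fix1(toy_model, r)  # z -> toy_model(r, z)
sum_sq = Base.Fix1(sum, abs2)          # y -> sum(abs2, y)

# Composition applies right to left: run the model, then sum the squares
loss = sum_sq ∘ apply_model

println(loss(z) ≈ sum(abs2, toy_model(r, z)))  # true
```

In `fn1`, `Base.Fix1(smodel, r)` therefore calls `smodel(r, z)`, passing `z` as the parameters, and `fn1` returns the sum of the entries of the inner gradient of this loss. Differentiating `fn1` with respect to `z` then needs second derivatives: the result is the Hessian of the loss applied to a vector of ones, which is why this is a nested AD problem.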
Attention: Patch coverage is 85.10638% with 14 lines in your changes missing coverage. Please review.
Project coverage is 87.86%. Comparing base (18b6efe) to head (aaa3b73). Report is 4 commits behind head on main.
| File | Patch coverage | Lines missing |
|---|---|---|
| src/helpers/nested_ad.jl | 78.68% | 13 :warning: |
| ext/LuxForwardDiffExt.jl | 96.00% | 1 :warning: |
I tried it. It works and yes, it's faster.
Tests pass locally :tada:. Once they pass on CI, I will merge and tag a release.