Closed prbzrg closed 2 months ago
It works with Lux v0.5.37.
How can I do it with Lux v0.5.41?
Is this a bug? Did I miss something?
Maybe related to https://github.com/LuxDL/Lux.jl/issues/286
Pullbacks are hard to differentiate directly; see https://github.com/LuxDL/Lux.jl/issues/610#issuecomment-2085160267. We just need some rrules for DifferentiationInterface.pullback. Zygote.pullback is almost never going to work unless someone finds a nice trick for writing the tangent of the pullback function. DI.pullback, on the other hand, is quite simple to handle; DEQs.jl already does that.
I'm still getting the error with
(br-3) pkg> st
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
[b0b7db55] ComponentArrays v0.15.11
[a0c0ee7d] DifferentiationInterface v0.3.3
[f6369f11] ForwardDiff v0.10.36
[b2108857] Lux v0.5.42
[e88e6eb3] Zygote v0.6.69
[9a3f8284] Random
(br-3) pkg> st --outdated
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
(br-3) pkg> st --outdated -m
Status `D:\Codes\Mine\bug-report\br-3\Manifest.toml`
Even after using DI:
using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface
nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)
function fn1(u, p)
    z, uJ = DifferentiationInterface.value_and_pullback(x -> snn(x, p), AutoZygote(), u, u)
    sum(uJ) + sum(z)
end
fn1(r, ps)
DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)
Use Lux.vector_jacobian_product
I tried:
using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface
nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)
function fn1(u, p)
    z, uJ = Lux.vector_jacobian_product(x -> snn(x, p), AutoZygote(), u, u)
    sum(uJ) + sum(z)
end
fn1(r, ps)
# DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps)
But the problem is still there. 😬
Closures don't work; see the first part of https://lux.csail.mit.edu/stable/manual/nested_autodiff. Also, https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.vector_jacobian_product returns only the VJP, not both the value and the VJP (returning both could be added later, but it doesn't change the code much).
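For readers unfamiliar with the term, here is a minimal first-order sketch of what a vector-Jacobian product (VJP) is. It is not code from this thread: the weight matrix `W`, the function `f`, and the vectors `x` and `v` are hypothetical, and it uses only plain `Zygote.pullback` rather than Lux. It does not attempt nested differentiation, which is the hard part discussed above; it only shows that the pullback returns both the value and a function computing the VJP, and checks the result against the hand-derived formula for `tanh.(W * x)`.

```julia
using Zygote

# Hypothetical small example, not taken from the thread.
W = Float32[1 2; 3 4]
f(x) = tanh.(W * x)

x = Float32[0.1, 0.2]
v = ones(Float32, 2)        # the vector in the vector-Jacobian product

# Zygote.pullback returns the value and a closure that computes the VJP.
y, back = Zygote.pullback(f, x)
vjp = only(back(v))         # back returns one entry per argument of f

# Hand-derived VJP: d/dx tanh(Wx) has Jacobian diag(1 - y^2) * W,
# so J' * v = W' * (v .* (1 .- y .^ 2)).
manual = W' * (v .* (1 .- y .^ 2))

@assert vjp ≈ manual
```

Differentiating `vjp` again with respect to parameters is a second-order computation, which is why the thread moves to `Lux.vector_jacobian_product` instead of nesting Zygote calls.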
Thanks, the problem is resolved. Final code:
using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface
nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
function fn1(u, ps, st)
    snn = StatefulLuxLayer(nn, ps, st)
    z = snn(u)
    uJ = Lux.vector_jacobian_product(snn, AutoZygote(), u, u)
    sum(uJ) + sum(z)
end
fn1(r, ps, st)
# DifferentiationInterface.gradient(x -> fn1(r, x, st), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps, st)
Capturing DI would have been the ideal situation, but it causes ambiguities, and I would have to manually define the functions for all possibilities, which would get messy: https://github.com/LuxDL/Lux.jl/issues/600#issuecomment-2094566210