Closed rmsrosa closed 4 months ago
Zygote doesn't do nested AD well. Lux 0.5.38 shipped an older version of LuxLib that allowed nested Zygote calls (in very limited cases, restricted to `Dense` and `Conv`), but it was considerably slower. Recent versions have significantly more robust nested AD support, but it requires some manual rewriting. See https://lux.csail.mit.edu/stable/manual/nested_autodiff for the details. Your use case can be solved with:
```julia
function loss_function(model, ps, st, sample_points)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    y_pred = smodel(sample_points)
    dy_pred = only(Zygote.gradient(sum ∘ smodel, sample_points))
    loss = sum(dy_pred .+ y_pred .^ 2 / 2)
    return loss, smodel.st, ()
end
```
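For completeness, here is a hedged sketch of wiring a loss of this shape into the training API. The model architecture, data, and optimizer settings below are placeholders invented for illustration, not taken from the original report:

```julia
using Lux, Optimisers, Random, Zygote

# Placeholder model and data; substitute your own.
model = Chain(Dense(1 => 16, tanh), Dense(16 => 1))
ps, st = Lux.setup(Random.default_rng(), model)
sample_points = rand(Float32, 1, 32)

# Bundle model, parameters, states, and optimizer into a train state.
tstate = Lux.Training.TrainState(model, ps, st, Adam(0.01f0))

# AutoZygote() selects Zygote for the outer gradient; Lux handles the
# inner (nested) derivative in the loss via its mixed-mode machinery.
grads, loss, stats, tstate =
    Lux.Training.compute_gradients(AutoZygote(), loss_function,
                                   sample_points, tstate)
```

The key point is that `loss_function` itself calls `Zygote.gradient`, while the outer call goes through `compute_gradients` rather than a second explicit `Zygote.gradient`.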
Note that while it might seem like we are still doing `Zygote.gradient(Zygote.gradient(...))` to compute the final gradients here, we are instead using a more optimized mixed-mode AD that combines ForwardDiff and Zygote. We recently had an extended discussion about this on Discourse.
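The mixed-mode idea can be illustrated outside of Lux with a standalone forward-over-reverse sketch: ForwardDiff pushes dual numbers through a Zygote reverse-mode gradient, which avoids the poorly supported reverse-over-reverse path. This is only an illustration of the general technique, not Lux's internal implementation, and it assumes Zygote and ForwardDiff are installed:

```julia
using Zygote, ForwardDiff

# A scalar function whose second derivative we want.
f(x) = sin(x) * x^2

# First derivative via reverse mode (Zygote).
df(x) = only(Zygote.gradient(f, x))

# Second derivative: forward mode (ForwardDiff) over the Zygote gradient.
d2f(x) = ForwardDiff.derivative(df, x)

# Compare against the analytic second derivative
# 2sin(x) + 4x*cos(x) - x^2*sin(x) to sanity-check.
d2f(1.0)
```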
Thanks for the quick reply! I tried your suggested `loss_function`, but I am still getting an error:
```julia
julia> Lux.Training.compute_gradients(vjp_rule, loss_function, sample_points, tstate_org)
ERROR: ArgumentError: type does not have a definite number of fields
```
I guess I don't understand enough to find my way out of this. I should read the Discourse discussion and the nested AD manual page to get a better understanding.
Are you using the latest version of Lux?
Yes:
```
(current_env) pkg> st
Status `~/Documents/git-repositories/julia/julia_random_stuff/zygote_bug/current_env/Project.toml`
  [b2108857] Lux v0.5.58
  [3bd65402] Optimisers v0.3.3
  [e88e6eb3] Zygote v0.6.70
```
I will have a look at this today evening
Thank you!
Now it works, and it is indeed much faster!
The following used to work in Lux v0.5.38 but errors in Lux v0.5.57. In both cases the other packages are Optimisers v0.3.3 and Zygote v0.6.70, with Julia Version 1.10.0.

The error in Lux v0.5.57 is:

In Lux v0.5.38 it just works:

I don't see any mutating array in the `loss_function`. It looks like something changed in `Lux.Training.compute_gradients` that broke it.