Closed: JakobAsslaender closed this issue 2 months ago.
First of all, you don't want to splat here if you care about performance. If you use concatenation instead, the error goes away:
function dudt_(u, p, t)
    st(p)([u; 1f1])   # concatenate u with the constant instead of splatting it
end
This works and performs much better, so just use that.
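In the forward pass the two forms build the same input vector. The difference discussed in this thread only appears in reverse mode, where the splatted form's pullback hands back the cotangent for u as a tuple. The minimal plain-Julia check below compares the two forward results and needs no Flux or Zygote. It assumes a length-4 state, matching the Dense(5, 50, tanh) input layer used here. The variable names are illustrative only.

```julia
# Splatting vs. concatenation: both build the same 5-element input vector
# when u has 4 entries (assumed here to match Dense(5, 50, tanh)).
u = Float32[0.1, 0.2, 0.3, 0.4]   # example state; the values are arbitrary

x_splat  = [u..., 1f1]   # each element of u becomes a separate argument of the array literal
x_concat = [u; 1f1]      # vcat(u, 1f1): concatenates the vector and the scalar

println(x_splat == x_concat)           # prints true
println(x_concat isa Vector{Float32})  # prints true
println(length(x_concat))              # prints 5
```

Since the values agree, switching to concatenation changes nothing about what the network sees. It only changes how the gradient with respect to u is represented.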
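The underlying issue with your original code is that Zygote returns the gradient with respect to the splatted argument as a tuple rather than a vector:
using Flux, Zygote

ann = Chain(Dense(5, 50, tanh), Dense(50, 4))
p, st = Flux.destructure(ann)   # flat parameter vector and a function that rebuilds the model

function dudt_(u, p, t)
    st(p)([u..., 1f1])          # splats u into the array literal
end

u0 = rand(Float32, 4)           # assumed state: any length-4 Float32 vector (4 + 1 constant = 5 inputs)
out, back = Zygote.pullback(dudt_, u0, p, 0f0)
d_u, d_p, d_t = back(rand(4))
typeof(d_u)                     # NTuple, not a Vector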
I'll open that as an upstream issue. For now I'm closing this one, since changing the function as above avoids the problem.
Describe the bug 🐞
Hi, in the MWE below, the function _vecjacobian! seems to create dλ of type Vector and tmp1 of type NTuple, and there is no method recursive_copyto!(::AbstractVector, ::Tuple). Adding a recursive_copyto!(::AbstractVector, ::Tuple) method resolves the error, but I am unsure whether this is the proper bug fix.
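The exact method the reporter added is not shown here. The sketch below is hypothetical and only illustrates what such a method would need to do for a flat tuple of numbers: copy each entry into the preallocated vector. To avoid defining a method of recursive_copyto! outside its package (type piracy), the helper gets a made-up name, copy_tuple_to_vector!, and uses only Base Julia. It is not part of RecursiveArrayTools.jl or SciMLSensitivity.jl.

```julia
# Hypothetical helper, NOT part of RecursiveArrayTools.jl or SciMLSensitivity.jl.
# Copies a flat tuple (e.g. an NTuple cotangent) into a preallocated vector,
# roughly what a recursive_copyto!(::AbstractVector, ::Tuple) method would do.
# Nested entries (arrays inside the tuple) are not handled by this sketch.
function copy_tuple_to_vector!(dest::AbstractVector, src::Tuple)
    length(dest) == length(src) || throw(DimensionMismatch(
        "destination has length $(length(dest)), tuple has length $(length(src))"))
    for (i, x) in zip(eachindex(dest), src)
        dest[i] = x   # setindex! converts each entry to eltype(dest)
    end
    return dest
end

dλ   = zeros(Float32, 4)        # stands in for the Vector-typed buffer
tmp1 = (0.5, -1.0, 2.0, 0.25)   # stands in for the NTuple-typed value
copy_tuple_to_vector!(dλ, tmp1)
```

This only patches over the type mismatch. The maintainer's suggestion in this thread is to avoid the tuple in the first place by concatenating instead of splatting.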
Expected behavior
No error ;)
Minimal Reproducible Example 👇
Error & Stacktrace ⚠️
Environment (please complete the following information):
using Pkg; Pkg.status()
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
versioninfo()
Additional context
It could easily be that the bug is in my script; apologies if that is the case. If my fix is actually the right way to go, I would be happy to open a PR.
Thanks for looking into it!