Closed: mayor-slash closed this issue 11 months ago
Try changing `result = vcat(dN..., dT)` to not use splatting: `vcat(dN, dT)` should work. If you splat, Zygote doesn't preserve the Vector structure.
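As a minimal sketch of the difference (the values below are hypothetical; `n_species` is assumed to be 7 to match the error message):

```julia
# Hypothetical illustration: dN is a Vector of species derivatives, dT a scalar.
dN = [0.1, -0.2, 0.3, 0.05, -0.1, 0.0, 0.02]   # 7 species (example values)
dT = -1.5

# Splatting passes each element of dN as a separate scalar argument:
# vcat(dN..., dT) is vcat(0.1, -0.2, ..., dT).
# Zygote's pullback for that call produces a 7-tuple of scalar cotangents
# for dN instead of a Vector, which later fails when it must be added to
# an array cotangent: +(::NTuple{7, Float64}, ::Vector{Float64}).
result_splat = vcat(dN..., dT)

# Passing the Vector directly keeps the array structure intact:
result_vec = vcat(dN, dT)

# Both forms produce the same forward value...
@assert result_splat == result_vec
# ...but only the non-splatted form keeps a Vector cotangent under Zygote.
```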
For Neural ODEs, most of the sensitivity algorithms assume a flat parameter structure (unlike the default Lux parameter return type, which is a nested named tuple). See https://lux.csail.mit.edu/dev/tutorials/intermediate/1_NeuralODE#training, which describes some adjoints that don't require a flat parameter structure.
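For example, an adjoint can be selected via the `sensealg` keyword when solving. This is a sketch assuming SciMLSensitivity is loaded; `rhs!`, `u0`, `tspan`, and `p_flat` are placeholders for the actual problem setup:

```julia
using OrdinaryDiffEq, SciMLSensitivity

# Hypothetical ODE problem wrapping the model; p_flat is a flat
# (e.g. ComponentVector) parameter vector.
prob = ODEProblem(rhs!, u0, tspan, p_flat)

# InterpolatingAdjoint with a Zygote vector-Jacobian product is one of the
# adjoints that works with flat parameters:
sol = solve(prob, Tsit5();
            sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
```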
Verified locally that the suggested patch works:
```julia
function (m::CustomModel)(x, ps, st::NamedTuple)
    rezi_T = 1000 / x[n_species + 1]               # reciprocal temperature input
    Ns = x[1:n_species]                            # species concentrations
    model_inp = vcat(log10.(Ns), rezi_T)
    k, st_k = m.k_model(model_inp, ps.k_model, st.k_model)
    ks = 10 .^ k                                   # rate constants
    eff_Ns = vec(prod(Ns .^ -min.(stoich_m', 0.0), dims=1))
    qk = eff_Ns .* ks                              # reaction rates
    dN = vec(sum(qk .* stoich_m, dims=1))          # species derivatives
    uscp, st_t = m.t_model([rezi_T], ps.t_model, st.t_model)
    us = 1e8 .* uscp[1:n_species]                  # internal energies
    cp = sum((1e4 .* uscp[n_species+1:n_species*2]) .* Ns)  # heat capacity
    dU = sum(us .* dN)
    dT = -dU / cp                                  # temperature derivative
    result = vcat(dN, dT)                          # keep dN as a Vector; no splatting
    return result, (k_model = st_k, t_model = st_t)
end
```
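For reference, the surrounding struct (not shown in this thread) would typically be declared as a Lux container layer so that `ps` and `st` carry the `k_model`/`t_model` fields used above. A sketch; depending on the Lux version, the supertype may instead be named `AbstractLuxContainerLayer`:

```julia
using Lux

# Hypothetical declaration: registering the two sub-chains as a container
# layer makes Lux.setup produce nested (k_model = ..., t_model = ...)
# parameter and state NamedTuples.
struct CustomModel{K, T} <: Lux.AbstractExplicitContainerLayer{(:k_model, :t_model)}
    k_model::K
    t_model::T
end
```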
I want to create a custom model that holds two Lux.Chains. The code above demonstrates this.
Zygote throws a huge error with `MethodError: no method matching +(::NTuple{7, Float64}, ::Vector{Float64})`.
I am new to Julia, but I guess somewhere in Zygote's ChainRules a tuple is created that can't be added to a normal vector. Is there a bad implementation on my part, or is there a reason why ForwardDiff can compute the gradient and Zygote can't?
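The error pattern can be sketched in isolation: splatting a Vector gives Zygote a tuple cotangent, while using it as an array gives a Vector cotangent, and when Zygote must accumulate the two it falls back to `+`, for which no tuple-plus-vector method exists. A minimal sketch that, in affected Zygote versions, can trigger the same `MethodError`:

```julia
using Zygote

v = ones(7)

# Using v both splatted (tuple cotangent) and as a Vector (array cotangent)
# in the same function forces Zygote to add the two cotangent types together.
f(v) = sum(vcat(v...)) + sum(v)

Zygote.gradient(f, v)  # may fail with +(::NTuple{7, Float64}, ::Vector{Float64})
```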
Bonus question: Why do I have to convert my parameters (p) to a ComponentVector? I got that from some other tutorials on Neural ODE training.
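Regarding the bonus question: the adjoint sensitivity algorithms expect the parameters to be a flat `AbstractVector`, and a `ComponentVector` (from ComponentArrays.jl) flattens Lux's nested NamedTuple into one contiguous vector while still allowing access by name. A sketch with a hypothetical single chain:

```julia
using Lux, ComponentArrays, Random

# Hypothetical chain standing in for the two sub-models in this thread.
rng = Random.default_rng()
chain = Lux.Chain(Lux.Dense(3 => 8, tanh), Lux.Dense(8 => 2))
ps, st = Lux.setup(rng, chain)

# ps is a nested NamedTuple: (layer_1 = (weight = ..., bias = ...), ...).
# ComponentVector flattens it into one contiguous vector but keeps the names,
# so sensitivity algorithms can treat it as a plain Vector:
p = ComponentVector(ps)
@assert p isa AbstractVector
@assert p.layer_1.weight == ps.layer_1.weight
```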