LuxDL / Lux.jl

Elegant and Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License
500 stars 61 forks source link

Zygote gradient fails for Custom Layer #457

Closed mayor-slash closed 11 months ago

mayor-slash commented 11 months ago

I want to create a custom model, that holds two Lux.Chains within. The following Code demonstates this:

using ComponentArrays,Zygote,Lux,StableRNGs,ForwardDiff

# Create reproducible rng
rng = StableRNG(131)

# Define Some constants
species = [     "H",    "H2",   "O",    "OH","  H2O",   "O2",   "N2","T"]
stoich_m = [  [ 2.0,    -1.0,   0.0,    0.0,    0.0,    0.0,    0.0     ],   
              [ -1.0,   0.0,    1.0,    1.0,    0.0,    -1.0,   0.0    ],
              [ 1.0,    -1.0,   0.0,    -1.0,   1.0,    0.0,    0.0    ],
              [ 1.0,   -1.0,   -1.0,    1.0,   0.0,    0.0,    0.0    ]]
stoich_m = hcat(stoich_m..., (-1 .* stoich_m)...)'
stoich_m = stoich_m
n_nodes = 20
n_species = 7
n_react = 8

# Build models that sit inside custom Model

t_model = Lux.Chain(
    Lux.Dense(1, n_nodes, tanh),
    Lux.Dense(n_nodes, n_nodes, tanh),
    Lux.Dense(n_nodes, 2*n_species))

k_model = Lux.Chain(
    Lux.Dense(1+n_species, n_nodes, tanh),
    Lux.Dense(n_nodes, n_nodes, tanh),
    Lux.Dense(n_nodes, n_react))

# Build Struct and associated function
struct CustomModel{k,t} <: Lux.AbstractExplicitContainerLayer{(:k_model, :t_model)}
    k_model::k
    t_model::t
end

function (m::CustomModel)(x, ps, st::NamedTuple)
    rezi_T = 1000/x[n_species+1]
    Ns = x[1:n_species]
    model_inp = vcat(log10.(Ns)..., rezi_T)
    k, st_k = m.k_model(model_inp, ps.k_model, st.k_model)
    ks = 10 .^ k
    eff_Ns =  vec(prod(Ns .^ -min.(stoich_m', 0.), dims=1))
    qk = eff_Ns .* ks
    dN = vec(sum(qk .* stoich_m, dims=1))
    uscp, st_t = m.t_model([rezi_T], ps.t_model, st.t_model)
    us = 1e8 .* uscp[1:n_species]
    cp = sum((1e4 .* uscp[n_species+1:n_species*2]) .* Ns)
    dU = sum(us .* dN)
    dT = -dU/cp
    result = vcat(dN...,dT)
    # println(typeof(model_inp),size(model_inp))
    # println(typeof(k),size(k))
    # println(typeof(ks),size(ks))
    # println(typeof(eff_Ns),size(eff_Ns))
    # println(typeof(qk),size(qk))
    # println(typeof(dN),size(dN))
    # println(typeof(uscp),size(uscp))
    # println(typeof(result),size(result))
    return result, (k_model = st_k, t_model = st_t)
end

# Create Model and run setup to get parameters and state
model = CustomModel(k_model, t_model)
p,st = Lux.setup(rng, model)
p = ComponentVector(p)

# Create random test vector that is strictly positive
x = randn( Float64, 8) .^2 .+ 1e-20

# Create simple "loss"-function to differentiate 
function f(p)
    y = model(x,p,st)[1]
    return sum(y)
end

println(f(p))                   # This works
df = ForwardDiff.gradient(f,p)  # This works
df2 = Zygote.gradient(f, p)     # This doesnt work

Zygote throws a huge error with the MethodError: no method matching +(::NTuple{7, Float64}, ::Vector{Float64})

I am new to Julia but I guess somewhere in the chainrules of Zygote a NamedTuple is created and it can't be added to a normal vector. Is there some bad implementation on my part or is there a reason why ForwardDiff can get the gradient and Zygote doesnt?

Bonus question: Why do I have to convert my parameters (p) to a ComponentVector? I got that from some other tutorials on Neural ODE training.

avik-pal commented 11 months ago

Try changing result = vcat(dN...,dT) to not use splatting. Something like vcat(dN, dT) should work. If you splat, Zygote doesn't preserve the Vector structure

For NeuralODEs most of the sensitivity algorithms assume a flat parameter structure (unlike the default return type which is a nested named tuple). See https://lux.csail.mit.edu/dev/tutorials/intermediate/1_NeuralODE#training, it describes some Adjoints which don't require the flat parameter structure.

avik-pal commented 11 months ago

Verified locally that the suggested patch works:

function (m::CustomModel)(x, ps, st::NamedTuple)
    rezi_T = 1000/x[n_species+1]
    Ns = x[1:n_species]
    model_inp = vcat(log10.(Ns), rezi_T)
    k, st_k = m.k_model(model_inp, ps.k_model, st.k_model)
    ks = 10 .^ k
    eff_Ns =  vec(prod(Ns .^ -min.(stoich_m', 0.), dims=1))
    qk = eff_Ns .* ks
    dN = vec(sum(qk .* stoich_m, dims=1))
    uscp, st_t = m.t_model([rezi_T], ps.t_model, st.t_model)
    us = 1e8 .* uscp[1:n_species]
    cp = sum((1e4 .* uscp[n_species+1:n_species*2]) .* Ns)
    dU = sum(us .* dN)
    dT = -dU/cp
    result = vcat(dN,dT)
    return result, (k_model = st_k, t_model = st_t)
end