LuxDL / Lux.jl

Elegant and Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Zygote nested AD failure #604

Closed — vavrines closed this issue 6 months ago

vavrines commented 6 months ago

It is mentioned in the manual that "You can use Zygote.jacobian as well but ForwardDiff tends to be more efficient here". Switching to Zygote:

using Lux, LinearAlgebra, Zygote, Random, ComponentArrays

function loss_function1(model, x, ps, st, y)
    smodel = StatefulLuxLayer(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    J = Zygote.jacobian(smodel, x) # used to be ForwardDiff
    loss_reg = abs2(norm(J))
    return loss_emp + loss_reg
end

model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Xoshiro(0), Float32, 2, 10)
y = rand(Xoshiro(11), Float32, 2, 10)

loss_function1(model, x, ps, st, y)
_, ∂x, ∂ps, _, _ = Zygote.gradient(loss_function1, model, x, ps, st, y)

I got:

ERROR: MethodError: no method matching ForwardDiff.Dual{ForwardDiff.Tag{LuxZygoteExt.var"#16#20"{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{true, Chain{@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::BatchNorm{true, true, Float32, typeof(identity), typeof(zeros32), typeof(ones32)}, layer_3::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{scale::Vector{Float32}, bias::Vector{Float32}}, layer_3::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{running_mean::Vector{Float32}, running_var::Vector{Float32}, training::Val{true}}, layer_3::@NamedTuple{}}}}}, Float64}, Float64, 1}(::Float32, ::ForwardDiff.Partials{1, Float64})

Closest candidates are:
  ForwardDiff.Dual{T, V, N}(::Number) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:78
  ForwardDiff.Dual{T, V, N}(::Any) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:77
  ForwardDiff.Dual{T, V, N}(::V, !Matched::ForwardDiff.Partials{N, V}) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:17

Env:

Lux v0.5.39
avik-pal commented 6 months ago
function loss_function1(model, x, ps, st, y)
    smodel = StatefulLuxLayer(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    J = only(Zygote.jacobian(smodel, x)) # used to be ForwardDiff
    loss_reg = abs2(norm(J))
    return loss_emp + loss_reg
end

Zygote.<....> functions return a tuple (one entry per differentiated argument), so you need to extract the result with an `only` or a `first`.
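To illustrate the point above, here is a minimal sketch (independent of Lux, using a plain function) showing that `Zygote.gradient` and `Zygote.jacobian` both wrap their results in a tuple, one entry per argument:

```julia
using Zygote

f(x) = sum(abs2, x)
x = rand(Float32, 3)

# Zygote.gradient returns a tuple with one entry per argument of f.
g = Zygote.gradient(f, x)
@assert g isa Tuple && length(g) == 1

# Extract the single entry with `only` (or `first`).
∇x = only(Zygote.gradient(f, x))
@assert ∇x ≈ 2 .* x

# The same applies to Zygote.jacobian.
J = only(Zygote.jacobian(f, x))
@assert size(J) == (1, 3)
```

Without the `only`, downstream code such as `norm(J)` receives the tuple wrapper instead of the array, which is what triggered the `MethodError` in the nested AD call above.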