LuxDL / Lux.jl

Elegant and Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Error in `compute_gradients` when loss already has a `Zygote.gradient` #743

Closed rmsrosa closed 4 months ago

rmsrosa commented 4 months ago

The following used to work in Lux v0.5.38 but errors in Lux v0.5.57. In both cases, the other packages are Optimisers v0.3.3 and Zygote v0.6.70, running on Julia 1.10.0.

using Lux
using Zygote
using Optimisers
using Random

model = Chain(Dense(1 => 8, sigmoid), Dense(8 => 1))
rng = Xoshiro(12345)

ps, st = Lux.setup(rng, model)

function loss_function(model, ps, st, sample_points)
    sample_points
    y_pred, st = Lux.apply(model, sample_points, ps, st)
    dy_pred = Zygote.gradient(s -> sum(model(s, ps, st)[1]), sample_points)[1]
    loss = sum(dy_pred .+ y_pred .^2 / 2)
    return loss, st, ()
end

opt = Adam(0.01)

tstate_org = Lux.Training.TrainState(rng, model, opt)

vjp_rule = Lux.Training.AutoZygote()

dev_cpu = cpu_device()

sample_points = permutedims(randn(16))

Lux.Training.compute_gradients(vjp_rule, loss_function, sample_points, tstate_org)

The error in Lux v0.5.57 is

(current_env) pkg> st
Status `~/Documents/git_repositories/julia/julia_random_stuff/zygote_bug/Project.toml`
  [b2108857] Lux v0.5.57
  [3bd65402] Optimisers v0.3.3
  [e88e6eb3] Zygote v0.6.70

julia> Lux.Training.compute_gradients(vjp_rule, loss_function, sample_points, tstate_org)
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:70
  [3] (::Zygote.var"#539#540"{Vector{Float64}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] _mapreducedim!
    @ ./reducedim.jl:317 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [7] mapreducedim!
    @ ./reducedim.jl:324 [inlined]
  [8] #sum!#852
    @ ./reducedim.jl:1034 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [10] sum!
    @ ./reducedim.jl:1034 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [12] #sum!#853
    @ ./reducedim.jl:1036 [inlined]
 [13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [14] sum!(r::AbstractArray, A::AbstractArray)
    @ Base ./reducedim.jl:1036 [inlined]
 [15] __added_bias_gradient(b::AbstractArray, Δ::Any)
    @ LuxLib ~/.julia/packages/LuxLib/JTAYi/src/utils.jl:153 [inlined]
 [16] __matmul_bias_partials
    @ ~/.julia/packages/LuxLib/JTAYi/src/impl/fused_dense.jl:78 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] #27
    @ ~/.julia/packages/LuxLib/JTAYi/src/impl/fused_dense.jl:47 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/LuxLib/JTAYi/src/api/dense.jl:46 [inlined]
 [23] (::Zygote.Pullback{…})(Δ::Tuple{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/.julia/packages/LuxLib/JTAYi/src/api/dense.jl:38 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Nothing, FillArrays.Fill{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] Pullback
    @ ~/.julia/packages/Lux/LhwgF/src/layers/basic.jl:357 [inlined]
 [27] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [28] Pullback
    @ ~/.julia/packages/LuxCore/qeN7D/src/LuxCore.jl:175 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] Pullback
    @ ~/.julia/packages/Lux/LhwgF/src/layers/containers.jl:0 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [32] Pullback
    @ ~/.julia/packages/Lux/LhwgF/src/layers/containers.jl:496 [inlined]
 [33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] Pullback
    @ ~/Documents/git_repositories/julia/julia_random_stuff/zygote_bug/double_diff_zygote_bug.jl:14 [inlined]
 [35] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [36] #75
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91 [inlined]
 [37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{FillArrays.Fill{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [38] gradient
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148 [inlined]
 [39] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{FillArrays.Fill{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [40] loss_function
    @ ~/Documents/git_repositories/julia/julia_random_stuff/zygote_bug/double_diff_zygote_bug.jl:14 [inlined]
 [41] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float64, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Tuple{Float64, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [43] compute_gradients(::AutoZygote, objective_function::typeof(loss_function), data::Matrix{…}, ts::Lux.Experimental.TrainState{…})
    @ LuxZygoteExt ~/.julia/packages/Lux/LhwgF/ext/LuxZygoteExt/training.jl:5
 [44] compute_gradients(::AutoZygote, ::Vararg{Any}; kwargs::@Kwargs{})
    @ Lux.Training ~/.julia/packages/Lux/LhwgF/src/contrib/contrib.jl:50
 [45] compute_gradients(::AutoZygote, ::Vararg{Any})
    @ Lux.Training ~/.julia/packages/Lux/LhwgF/src/contrib/contrib.jl:48
 [46] top-level scope
    @ ~/Documents/git_repositories/julia/julia_random_stuff/zygote_bug/double_diff_zygote_bug.jl:29
Some type information was truncated. Use `show(err)` to see complete types.

In Lux v0.5.38 it just works:

(previous_env) pkg> st
Status `~/Documents/git_repositories/julia/julia_random_stuff/zygote_bug/previous_env/Project.toml`
⌃ [b2108857] Lux v0.5.38
  [3bd65402] Optimisers v0.3.3
  [e88e6eb3] Zygote v0.6.70
Info Packages marked with ⌃ have new versions available and may be upgradable.

julia> Lux.Training.compute_gradients(vjp_rule, loss_function, sample_points, tstate_org)
((layer_1 = (weight = Float32[1.4885416; -1.3188757; … ; -1.542988; 1.7500782;;], bias = Float32[0.08510557; -0.09157323; … ; -0.12596396; 0.19119978;;]), layer_2 = (weight = Float32[0.28484756 -0.710527 … 1.945188 2.4889088], bias = Float32[0.8374626;;])), -1.5246360029540207, (), Lux.Experimental.TrainState{Chain{@NamedTuple{layer_1::Dense{true, typeof(sigmoid_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}}(Chain{@NamedTuple{layer_1::Dense{true, typeof(sigmoid_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}, Nothing}((layer_1 = Dense(1 => 8, sigmoid_fast), layer_2 = Dense(8 => 1)), nothing), (layer_1 = (weight = Float32[-0.036469776; -0.31327224; … ; 0.43015236; 0.6011922;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.40568522 -0.379589 … -0.46527746 0.5768912], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), (layer_1 = (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 
= (weight = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.01, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0))

I don't see any array mutation in loss_function. It looks like something changed in Lux.Training.compute_gradients that broke it.
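For context, the stack trace above points at sum! inside LuxLib's __added_bias_gradient (frames 14 and 15), so the mutation appears to originate in library code that Zygote reaches while differentiating the inner Zygote.gradient call, not in loss_function itself. The following is a minimal sketch of the same class of failure, independent of Lux. The helper name fill_then_sum is hypothetical and chosen only for illustration.

```julia
using Zygote

# Hypothetical helper (not from the thread): writes into a buffer in place
# before reducing it.
function fill_then_sum(x)
    buf = zeros(eltype(x), length(x))
    buf[1] = x[1]   # in-place write (setindex!), which Zygote cannot differentiate
    return sum(buf)
end

# Capture the error instead of letting it propagate to the REPL.
err = try
    Zygote.gradient(fill_then_sum, [1.0, 2.0])
    nothing
catch e
    e
end

println(err isa ErrorException)
```

As in the trace above, the forward pass runs normally and the error is raised only when Zygote reaches the pullback of the in-place write, which is why the mutation can sit deep inside a dependency without being visible in user code.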

avik-pal commented 4 months ago

Zygote doesn't do nested AD well. Lux 0.5.38 shipped an older version of LuxLib that allowed nested Zygote calls (only in very limited cases, restricted to Dense and Conv) but was considerably slower.

In recent versions, we have significantly more robust nested AD support, but it requires some manual rewriting. See https://lux.csail.mit.edu/stable/manual/nested_autodiff for details. Your use case can be solved with:

function loss_function(model, ps, st, sample_points)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    y_pred = smodel(sample_points)
    dy_pred = only(Zygote.gradient(sum ∘ smodel, sample_points))
    loss = sum(dy_pred .+ y_pred .^2 / 2)
    return loss, smodel.st, ()
end

Note that while it might seem like we are still doing Zygote.gradient(Zygote.gradient(...)) to compute the final gradients here, we are instead using a more optimized mixed-mode AD that combines ForwardDiff and Zygote. We recently had an extended discussion about this on Discourse.
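To illustrate the general idea of combining forward and reverse mode (this is only a sketch of the forward-over-reverse pattern on a plain function, not Lux's internal implementation), the inner gradient below is computed by Zygote and the outer derivative by ForwardDiff. The names f, inner_grad, and g are hypothetical.

```julia
using ForwardDiff
using Zygote

# Scalar test function with known derivatives:
# gradient 3 .* x .^ 2, Hessian diagonal 6 .* x.
f(x) = sum(x .^ 3)

# Inner derivative, computed in reverse mode by Zygote.
inner_grad(x) = only(Zygote.gradient(f, x))

# A loss that depends on the inner gradient, analogous to dy_pred in the thread.
g(x) = sum(inner_grad(x))

x = [1.0, 2.0, 3.0]

# Outer derivative in forward mode: ForwardDiff pushes dual numbers through
# the Zygote call, so Zygote never has to differentiate its own pullback.
outer = ForwardDiff.gradient(g, x)

# Compare against the analytic result 6 .* x.
println(outer ≈ 6 .* x)
```

Because the outer pass only evaluates the inner Zygote call on dual numbers, it avoids the Zygote-over-Zygote path that produced the mutation error reported above.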

rmsrosa commented 4 months ago

Thanks for the quick reply!

I tried your suggestion of loss_function, but I am still getting an error.

julia> Lux.Training.compute_gradients(vjp_rule, loss_function, sample_points, tstate_org)
ERROR: ArgumentError: type does not have a definite number of fields

I guess I don't understand enough to find my way out of this. I should read the Discourse discussion and the nested AD page to get a better understanding.

avik-pal commented 4 months ago

Are you using the latest version of Lux?

rmsrosa commented 4 months ago

> Are you using the latest version of Lux?

Yes:

(current_env) pkg> st
Status `~/Documents/git-repositories/julia/julia_random_stuff/zygote_bug/current_env/Project.toml`
  [b2108857] Lux v0.5.58
  [3bd65402] Optimisers v0.3.3
  [e88e6eb3] Zygote v0.6.70

avik-pal commented 4 months ago

I will have a look at this this evening.

rmsrosa commented 4 months ago

Thank you!

Now it works, and it is indeed much faster!