jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.

Construction of ComponentArray inside of AD/Zygote #176

Open frankschae opened 1 year ago

frankschae commented 1 year ago

I want to compute the gradient of a loss function with respect to a ComponentArray. Inside the loss function, I need to reconstruct a ComponentArray. Based on @jonniedie's reply in https://github.com/jonniedie/ComponentArrays.jl/issues/126#issuecomment-1141580528, I tried

using ComponentArrays, UnPack, Zygote

function my_sum(v)
    ax = getaxes(v)
    @unpack x, y = v
    ca = ComponentArray([x..., y...], ax)
    return sum(ca.x + ca.y)
end

Zygote.gradient(my_sum, ComponentArray(x=[0.0], y=[0.0]))

which fails with

ERROR: ArgumentError: indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting `.=` instead?
Stacktrace:
  [1] setindex_shape_check(::ChainRulesCore.Tangent{Any, Tuple{Float64}}, ::Int64)
    @ Base ./indices.jl:261
  [2] _unsafe_setindex!(#unused#::IndexLinear, A::Vector{Float64}, x::ChainRulesCore.Tangent{Any, Tuple{Float64}}, I::UnitRange{Int64})
    @ Base ./multidimensional.jl:939
  [3] _setindex!
    @ ./multidimensional.jl:930 [inlined]
  [4] setindex!
    @ ./abstractarray.jl:1344 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/ComponentArrays/EjZNJ/src/array_interface.jl:0 [inlined]
  [6] _setindex!(x::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:1, y = 2:2)}}}, v::ChainRulesCore.Tangent{Any, Tuple{Float64}}, idx::Val{:y})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/EjZNJ/src/array_interface.jl:129
  [7] setproperty!
    @ ~/.julia/packages/ComponentArrays/EjZNJ/src/namedtuple_interface.jl:17 [inlined]
  [8] (::ComponentArrays.var"#getproperty_adjoint#87"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:1, y = 2:2)}}}, Symbol})(Δ::ChainRulesCore.Tangent{Any, Tuple{Float64}})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/EjZNJ/src/compat/chainrulescore.jl:4
  [9] ZBack
    @ ~/.julia/packages/Zygote/PD12J/src/compiler/chainrules.jl:206 [inlined]
 [10] Pullback
    @ ~/.julia/packages/UnPack/EkESO/src/UnPack.jl:34 [inlined]
 [11] (::typeof(∂(unpack)))(Δ::Tuple{Float64})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [12] macro expansion
    @ ~/.julia/packages/UnPack/EkESO/src/UnPack.jl:101 [inlined]
 [13] Pullback

pointing to the @unpack call. @avik-pal noted that the same error occurs even without @unpack:

function my_sum(v)
    ax = getaxes(v)
    ca = ComponentArray([v.x..., v.y...], ax)
    return sum(ca.x + ca.y)
end

Zygote.gradient(my_sum, ComponentArray(x=[0.0], y=[0.0]))

but it is resolved by using vcat instead of splatting:

function my_sum(v)
    ax = getaxes(v)
    @unpack x, y = v
    ca = ComponentArray(vcat(x,y), ax)
    return sum(ca.x + ca.y)
end

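The issue seems to be that, when the components are splatted, Δ arrives as a Tangent wrapping a Tuple{Float64} rather than as an array (see frame [8] of the stack trace) in https://github.com/jonniedie/ComponentArrays.jl/blob/cbb24ef7156d18f1576ea48d7ae42023cc5bfa70/src/compat/chainrulescore.jl#L4, and the adjoint then tries to assign it into the component.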
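
For completeness, here is a minimal sketch of the vcat workaround run end to end. It only combines the function and the gradient call already shown above; the name my_sum_vcat is hypothetical, and whether it runs cleanly may depend on the installed ComponentArrays and Zygote versions.

```julia
using ComponentArrays, UnPack, Zygote

# Same reconstruction as above, but concatenating with vcat instead of splatting.
function my_sum_vcat(v)
    ax = getaxes(v)                       # axes of the incoming ComponentArray
    @unpack x, y = v                      # pull out the named components
    ca = ComponentArray(vcat(x, y), ax)   # rebuild from a plain concatenated vector
    return sum(ca.x + ca.y)
end

# Zygote.gradient returns a one-element tuple; destructure it.
g, = Zygote.gradient(my_sum_vcat, ComponentArray(x=[0.0], y=[0.0]))
```

Since the loss is linear in both components, each entry of g should be one.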

kaandocal commented 1 year ago

Constructing a ComponentArray inside the differentiated function fails with Zygote in general:

using Zygote
using ComponentArrays

Zygote.gradient(x -> ComponentArray(a = [5])[1], [0.])

gives

ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:86
  [3] (::Zygote.var"#397#398"{Vector{Any}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
  [4] (::Zygote.var"#2508#back#399"{Zygote.var"#397#398"{Vector{Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] Pullback
    @ ./namedtuple.jl:309 [inlined]
  [6] (::typeof(∂(merge)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:161 [inlined]
  [8] (::typeof(∂(make_idx)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:147 [inlined]
 [10] (::typeof(∂(make_carray_args)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:139 [inlined]
 [12] (::typeof(∂(make_carray_args)))(Δ::Tuple{Vector{Float64}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:63 [inlined]
 [14] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
 [15] (::typeof(∂(#ComponentArray#21)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
 [17] (::typeof(∂(Type##kw)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./REPL[4]:1 [inlined]
 [19] (::typeof(∂(#3)))(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#60#61"{typeof(∂(#3))})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [21] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [22] top-level scope
    @ REPL[4]:1

It seems that, to work with Zygote, ComponentArrays must entirely avoid mutating arrays, including appending to them (here, a push! on a Vector{Any} reached through merge in make_idx).

Yuan-Ru-Lin commented 1 year ago

Is there any chance of resolving the above error by using Zygote.Buffer, as illustrated here?
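
To make the question concrete, here is a minimal sketch of the Zygote.Buffer pattern applied to user-level code. The helper concat_with_buffer is hypothetical and not part of ComponentArrays. It only shows how element-wise writes can be made differentiable in your own code; it does not touch the push! inside the ComponentArray keyword constructor, so whether Buffer can fix the error above remains open.

```julia
using Zygote

# Hypothetical helper: concatenate two vectors by writing into a Zygote.Buffer,
# which Zygote can differentiate through, unlike writes into a plain Array.
function concat_with_buffer(x, y)
    buf = Zygote.Buffer(x, length(x) + length(y))  # like similar(x, n)
    for i in eachindex(x)
        buf[i] = x[i]
    end
    for j in eachindex(y)
        buf[length(x) + j] = y[j]
    end
    return copy(buf)  # copy turns the Buffer back into an ordinary array
end

gx, gy = Zygote.gradient((x, y) -> sum(abs2, concat_with_buffer(x, y)),
                         [1.0, 2.0], [3.0])
```

The result of copy(buf) could then be passed to ComponentArray(data, ax) in the same way as the vcat result earlier in the thread.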