jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License

Incorrect gradient type. #235

Closed. YichengDWu closed this issue 4 months ago.

YichengDWu commented 8 months ago

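In the following code, shouldn't grad have the same type as the input ps? Instead, its underlying data is itself a ComponentVector, so the returned type is nested.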

julia> using ComponentArrays, Zygote

julia> ps = ComponentArray(weight=rand(3,3))
ComponentVector{Float64}(weight = [0.8258011674811379 0.44981527211623784 0.5507199474580698; 0.28250960116858315 0.42697555461561054 0.22836834345488244; 0.9092739196971233 0.6566982028049392 0.14458718099508994])

julia> grad = Zygote.gradient(x->sum(x.weight), ps)[1]
ComponentVector{Float64, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())),)}}}, Tuple{Axis{(weight = ViewAxis(1:9, ShapedAxis((3, 3), NamedTuple())),)}}}(weight = [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0])

jonniedie commented 8 months ago

Yeah, it should. That's weird. I might get a chance to look at it this weekend.

avik-pal commented 4 months ago

julia> ps = ComponentArray(weight=rand(3,3))
ComponentVector{Float64}(weight = [0.3364916316745368 0.16747972646975107 0.961114830406715; 0.7662896569815121 0.21684158229208028 0.8580540237329293; 0.13361739374272052 0.6107957888988388 0.27262419344861877])

julia> grad = Zygote.gradient(x->sum(x.weight), ps)[1]
ComponentVector{Float64}(weight = [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0])

This was fixed by the ProjectTo PR. As the output above shows, grad is now a plain ComponentVector{Float64}, the same type as ps.
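A regression check for this issue might look like the sketch below. It assumes Zygote is installed and that the ComponentArrays version includes the fix. It uses Julia's Test standard library together with the getdata and getaxes accessors from ComponentArrays. The strict typeof comparison assumes the fixed behavior shown in the last comment, where grad prints as a plain ComponentVector{Float64}.

```julia
using Test
using ComponentArrays, Zygote

ps = ComponentArray(weight = rand(3, 3))

# Gradient of the sum over the 3x3 "weight" component: every entry should be 1.0
grad = Zygote.gradient(x -> sum(x.weight), ps)[1]

@testset "gradient of a ComponentVector keeps its type (issue #235)" begin
    # In the buggy output, the data field was itself a ComponentVector (nested type)
    @test getdata(grad) isa Vector{Float64}
    # The gradient should carry the same named axes as the input
    @test getaxes(grad) == getaxes(ps)
    # Strictest check: the concrete types agree
    @test typeof(grad) == typeof(ps)
    # Values: d(sum)/d(weight) is all ones
    @test grad.weight == ones(3, 3)
end
```

Checking getdata separately from the full typeof comparison makes it clear whether a failure comes from the nested data field reported here or from some other difference, such as the axes.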