YichengDWu closed this issue 2 years ago.
If I overload `map`, I can get around it:
```julia-repl
julia> Base.map(f, x::ComponentArray, args...) = ComponentArray(map(f, getdata(x), args...), getaxes(x))

julia> map((x,y)->x+y, ca,[3,4])
ComponentVector{Int64}(a = 4, b = 6)
```
This could be dangerous, though.
OK, this is a bad idea. But the following idea is a good one:
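```julia
using ComponentArrays    # ComponentArray, getaxes
using ForwardDiff: Dual
import Zygote            # binds the module name so Zygote.seed can be extended

# Give element i one-hot partials within the current chunk of width N,
# then re-wrap the result so the named axes survive.
function Zygote.seed(x::ComponentArray, ::Val{N}, offset = 0) where N
    dual = map(x, reshape(1:length(x), size(x))) do xi, i
        Dual(xi, ntuple(j -> j + offset == i, Val(N)))
    end
    return ComponentArray(dual, getaxes(x))
end
```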
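The `ntuple(j -> j + offset == i, Val(N))` expression is what makes this a chunked forward-mode seed. The Base-only sketch below rebuilds just those partials tuples, without any `Dual`s, so the pattern is easy to inspect. The helper name `onehot_partials` is hypothetical and is not part of Zygote or ForwardDiff. Element `i` gets `true` in slot `i - offset` when that slot lies inside the current chunk of width `N`, and all-`false` partials otherwise.

```julia
# Hypothetical helper: the partials tuple the seed function attaches to element i
# for a chunk of width N that starts after `offset` already-seeded elements.
onehot_partials(i, N, offset = 0) = ntuple(j -> j + offset == i, N)

# A 3-element input seeded in a single chunk (N = 3, offset = 0):
chunk1 = [onehot_partials(i, 3) for i in 1:3]
# → [(true, false, false), (false, true, false), (false, false, true)]

# A 5-element input seeded in chunks of width 3: the second chunk uses offset = 3,
# so only elements 4 and 5 receive a nonzero partial.
chunk2 = [onehot_partials(i, 3, 3) for i in 1:5]
```

Each chunk therefore gives nonzero partials to at most `N` inputs, and `offset` selects which block of inputs the current chunk covers.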
This is closely related to #148.
The behavior of `map` should fall out of broadcast rules rather than be overloaded explicitly, because an explicit overload wouldn't cover the case where the 2nd, 3rd, etc. argument is a `ComponentArray` and the first isn't. I could probably change this so the `ComponentArray` “wins” here.
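The dispatch problem can be reproduced with Base alone. In this sketch `Labeled` is a hypothetical toy type standing in for `ComponentArray`; it is not part of ComponentArrays.jl, and it is deliberately not an `AbstractArray` so the example stays free of method ambiguities. An overload shaped like the `Base.map` one earlier in the thread only fires when the labeled container is the first collection. Otherwise Julia's generic `map` runs and the names are dropped.

```julia
# Hypothetical stand-in for ComponentArray: values plus names that should survive `map`.
struct Labeled{T}
    data::Vector{T}
    names::Vector{Symbol}
end

# Make Labeled iterable so Base's generic `map` can still consume it.
Base.iterate(x::Labeled, state...) = iterate(x.data, state...)
Base.length(x::Labeled) = length(x.data)

# Overload in the style of the one above: it dispatches only on the FIRST collection.
Base.map(f, x::Labeled, args...) = Labeled(map(f, x.data, args...), x.names)

v = Labeled([1, 2], [:a, :b])

first_wins = map(+, v, [3, 4])  # hits the overload: Labeled([4, 6], [:a, :b])
second     = map(+, [3, 4], v)  # generic Base fallback: plain Vector{Int}, names lost
```

Broadcasting does not have this asymmetry. Julia picks the result's `BroadcastStyle` by combining the styles of all arguments, so a custom style can take precedence regardless of argument position.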
The current behaviour of `map` is fine, as it is the same as with named tuples:
```julia-repl
julia> ca = (a=1,b=2)
(a = 1, b = 2)

julia> map((x,y)->x+y, ca,1:2)
2-element Vector{Int64}:
 2
 4
```
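For comparison, Base's `map` keeps the field names only when every argument is a named tuple with the same names. Mixing in a range falls back to the generic `map` and returns a plain `Vector`, exactly as in the session above. This small sketch shows both cases side by side.

```julia
nt = (a = 1, b = 2)

# Every argument is a NamedTuple with the same names: the names are preserved.
both_named = map((x, y) -> x + y, nt, (a = 3, b = 4))  # (a = 4, b = 6)

# The second argument is a range: generic map runs and the names are dropped.
mixed = map((x, y) -> x + y, nt, 1:2)                   # [2, 4], a Vector{Int}
```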
I'm closing this since it's more of a Zygote issue.
Related to #57.