JuliaArrays / AxisArrays.jl

Performant arrays where each dimension can have a named axis with values
http://JuliaArrays.github.io/AxisArrays.jl/latest/

WIP: Implement broadcasting with AxisArrays on Julia 0.7 #131

Open ajkeller34 opened 6 years ago

ajkeller34 commented 6 years ago

This PR proposes an implementation of broadcasting for AxisArrays that will be possible using Julia 0.7. I'm getting a bit ahead of myself because not all AxisArrays tests pass on 0.7, and I'm also aware that the new broadcasting API may continue to change (e.g. https://github.com/JuliaLang/julia/pull/25377). However, broadcasting is important enough for how I intend to use AxisArrays that I want to give an early demo, and also want to solicit some feedback before I sink too much time into this approach.

High-level description

Algorithm description

The following discussion relies upon understanding the new broadcasting API described in the interfaces section of the latest Julia docs. Broadcasting is intercepted after styles are combined, but before eltypes and indices are computed.

  1. combine_indices from any AxisArrays (but not other kinds of arrays) in the broadcasting operation. AxisArray axis names and values are returned from a new broadcast_indices method. As currently implemented, this demands exact equality of axis values, so tiny floating-point differences count. This returns a tuple of AxisArrays.Axis that we'll call axesAs.

  2. Provided that was successful, do broadcasting over all broadcast args using the underlying arrays (array.data if array is an AxisArray). Call the result broadcasted.

  3. Compare the axes in axesAs with the default_axes for broadcasted (which is not an AxisArray). We'll call the tuple of default axes defaxesBs. Note that length(axesAs) <= length(defaxesBs). Process these two tuples axesAs and defaxesBs taking pairs of elements axA, axB from each using Base.tail, etc.

3a. If the axis names match, then you need to see if you believe the axis from axA was originally a default axis. This PR makes the decision that if you have an axis like Axis{:row, <:Base.OneTo}, then it was a default axis. If so, return e.g. Axis{:row}(Base.OneTo(length(axB))) so that you resize the default axis to match the size required for broadcasted. If the values are not from Base.OneTo then it is not a default axis, and the arrays cannot be broadcasted.

3b. If the axis names don't match, then there's no need to worry about default axes, just return axA.

  4. Step 3 yields a tuple of axes. Wrap broadcasted into an AxisArray using the axes obtained from step 3. The number of axes you obtain from step 3 may be less than the number of dimensions of broadcasted, in which case the AxisArray constructor will use default_axes for the remainder.
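
As a rough illustration of step 3, here is a minimal sketch in plain Julia. It does not use the AxisArrays types or this PR's actual code: `Ax`, `isdefault`, and `reconcile` are hypothetical names invented for the example, and `Ax` is only a stand-in for `AxisArrays.Axis` (a name plus axis values).

```julia
# Hypothetical stand-in for AxisArrays.Axis: an axis name plus its values.
struct Ax{V}
    name::Symbol
    vals::V
end

# Assumption from step 3a: an axis whose values are a Base.OneTo range
# is treated as having been a default axis.
isdefault(ax::Ax) = ax.vals isa Base.OneTo

# Walk axesAs and defaxesBs pairwise using Base.tail.
# Recursion stops once the (shorter or equal) axesAs tuple is exhausted.
reconcile(::Tuple{}, defaxesBs::Tuple) = ()
function reconcile(axesAs::Tuple, defaxesBs::Tuple)
    axA, axB = first(axesAs), first(defaxesBs)
    rest = reconcile(Base.tail(axesAs), Base.tail(defaxesBs))
    if axA.name == axB.name
        # 3a: names match, so axA must have been a default axis;
        # resize it to the length required by the broadcast result.
        isdefault(axA) ||
            throw(DimensionMismatch("cannot broadcast axis $(axA.name)"))
        return (Ax(axA.name, Base.OneTo(length(axB.vals))), rest...)
    else
        # 3b: names differ, keep the axis from the AxisArray arguments.
        return (axA, rest...)
    end
end

# Example: a default :row axis of length 1 is resized to length 3,
# while the named :asdf axis is kept as-is.
axesAs    = (Ax(:row, Base.OneTo(1)), Ax(:asdf, [1.0, 2.0, 5.0]))
defaxesBs = (Ax(:row, Base.OneTo(3)), Ax(:col, Base.OneTo(3)))
result    = reconcile(axesAs, defaxesBs)
```

The tuple returned here corresponds to the axes that step 4 would pass to the AxisArray constructor; any trailing dimensions not covered by the tuple would fall back to default axes.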

Examples

See test/broadcast.jl, more tests/examples to come.
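
To make the exact-equality requirement from step 1 concrete, the following plain-Julia snippet (no AxisArrays involved, variable names are illustrative only) shows how two sets of axis values that look the same can still fail an `==` comparison, while an approximate comparison would accept them:

```julia
# Two candidate sets of axis values that differ only by rounding error.
vals_a = [0.1, 0.2, 0.3]
vals_b = [0.1, 0.2, 0.1 + 0.2]

# Exact comparison, as step 1 currently demands.
exact_match = vals_a == vals_b

# Element-wise approximate comparison, shown only for contrast.
approx_match = all(isapprox.(vals_a, vals_b))

println("exact: ", exact_match, ", approximate: ", approx_match)
```

Under the approach described above, axes like these would be treated as different, so the broadcast would be rejected.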

Relation to previous AxisArrays.jl issues and PRs concerning broadcasting

Issue 128

This PR satisfies what @omus considers an ideal solution in https://github.com/JuliaArrays/AxisArrays.jl/issues/128#issuecomment-340018520 (I've sanitized some deprecation warnings):

julia> A = AxisArray([1,2,3], Axis{:asdf}([1.0, 2.0, 5.0]))
3-element AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
 1
 2
 3

julia> A .* 2
3-element AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
 2
 4
 6

julia> A .== 2
3-element AxisArray{Bool,1,BitArray{1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
 false
  true
 false

PR 54

This PR also doesn't care about argument order, which was a limitation in PR https://github.com/JuliaArrays/AxisArrays.jl/pull/54:

julia> 1 .+ A
3-element AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
 2
 3
 4

It also doesn't care if the eltypes are Real, another limitation in https://github.com/JuliaArrays/AxisArrays.jl/pull/54:

julia> AxisArray([1+im, 2+im]) .+ (3.0+4.5im)
2-element AxisArray{Complex{Float64},1,Array{Complex{Float64},1},Tuple{Axis{:row,Base.OneTo{Int64}}}}:
 4.0 + 5.5im
 5.0 + 5.5im

Note that broadcasting is not oblivious to the underlying storage order, as mentioned in the high-level description, and there are differing opinions on that [1] [2]. However, this PR is very conservative, in that you can do strictly more with broadcasting while preserving the AxisArray wrapper. If there were another PR that paid no attention to the underlying storage order / did auto alignment, I think you would again have strictly more functionality, for some sense of the word strictly :) I'm not sure how broadcasting should be treated when combining both AxisArrays and AbstractArrays in that case; there you kind of need to pay attention to the storage order.

Known limitations

  1. Not everything is inferable yet; I'm trying to identify why.

  2. Some of the error messages are opaque when broadcasting doesn't work for AxisArray-specific reasons. I don't think this is insurmountable, but it would require some more boilerplate to fix.

  3. Axis info can get lost when using wrappers around AxisArrays, like with adjoint. This has been resolved as follows:

julia> A'
1×3 Adjoint{Int64,AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}}:
 1  2  3

julia> A' .+ [10,20,30]
3×3 AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:row,Base.OneTo{Int64}},Axis{:asdf,Array{Float64,1}}}}:
 11  12  13
 21  22  23
 31  32  33

julia> A' .+ AxisArray([10,20,30])
3×3 AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:row,Base.OneTo{Int64}},Axis{:asdf,Array{Float64,1}}}}:
 11  12  13
 21  22  23
 31  32  33

julia> A' .+ [10 20 30]
1×3 AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:row,Base.OneTo{Int64}},Axis{:asdf,Array{Float64,1}}}}:
 11  22  33

julia> A' .+ A'
1×3 Adjoint{Int64,AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}}:
 2  4  6

Note that as a consequence of requiring unique axis names for each dimension, A .+ A' fails. This is because the result array would have the same axis name for both column and row (:asdf). At first I wondered if Adjoint should really wrap AxisArrays like it does now, but that's actually consistent with this PR in that the underlying storage order is important in broadcasting. I think I'm fine with that; perhaps the README should say specifically that indexing can be oblivious to the storage order of the underlying array.

  4. Take a look at transpose(A):

julia> transpose(A)
1×3 AxisArray{Int64,2,Transpose{Int64,Array{Int64,1}},Tuple{Axis{:transpose,Base.OneTo{Int64}},Axis{:abc,Array{Float64,1}}}}:
 1  2  3

Probably AxisArrays should be updated to use the Transpose type that was introduced.

dkarrasch commented 5 years ago

I came across this package as I was looking for some "professional" solution instead of my own stupid hack, and it looks really good. It would be great to have AxisArray-preserving broadcasting, so what's the status of this PR?

lamorton commented 3 years ago

What's blocking this? Is there a fundamental problem, or maybe just some busywork with implementing tests? It seems like a solution to #156 and maybe #128.