meggart / DiskArrays.jl

Enable broadcasted assignment with trailing singleton dimensions #141

Closed: wkearn closed this pull request 7 months ago

wkearn commented 7 months ago

With regular Arrays, it is possible to do a broadcasted assignment to a destination with fewer dimensions than the source, as long as the source's extra trailing dimensions all have size 1:

dest = zeros(10,9)
src = rand(10,9,1,1)
dest .= src

@assert dest == src[:,:,1,1]
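As a quick illustration of the shape rule above, the sketch below checks that every source dimension beyond `ndims(dest)` is a singleton. It then shows an explicit alternative that does not rely on that broadcasting rule: dropping the singleton dimensions with `dropdims` so that both sides are 2-D. This is plain-Array code only; whether `dropdims` is a good workaround for a DiskArray source is not established here.

```julia
dest = zeros(10, 9)
src = rand(10, 9, 1, 1)

# The dimensions of `src` beyond ndims(dest) are all singletons, which is the
# condition under which `dest .= src` is accepted on newer Julia versions.
@assert all(==(1), size(src)[ndims(dest)+1:end])

# Explicit alternative: remove the singleton dimensions first, so source and
# destination have the same dimensionality before broadcasting.
dest .= dropdims(src; dims=(3, 4))
@assert dest == src[:, :, 1, 1]
```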

But if src is a DiskArray, this fails:

using DiskArrays

# Minimal in-memory wrapper used to reproduce the problem
struct _DiskArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::A                 # the wrapped in-memory array
    chunksize::NTuple{N,Int}  # chunk size along each dimension
end
_DiskArray(a; chunksize=size(a)) = _DiskArray(a, chunksize)
# Hook the type into DiskArrays.jl's generic method implementations
DiskArrays.@implement_diskarray _DiskArray
Base.size(a::_DiskArray) = size(a.parent)
DiskArrays.haschunks(::_DiskArray) = DiskArrays.Chunked()
DiskArrays.eachchunk(a::_DiskArray) = DiskArrays.GridChunks(a, a.chunksize)
# Block-wise read and write, delegated to the parent array
DiskArrays.readblock!(a::_DiskArray, aout, i::AbstractUnitRange...) = aout .= a.parent[i...]
DiskArrays.writeblock!(a::_DiskArray, v, i::AbstractUnitRange...) = view(a.parent, i...) .= v

data = rand(10,9,1,1)
src = _DiskArray(data)
dest = zeros(10,9)
dest .= src

@assert dest == data[:,:,1,1]

with the error:

ERROR: MethodError: no method matching splittuple()

Closest candidates are:
  splittuple(::Any, ::Any...)
   @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:127

Stacktrace:
  [1] maybeonerange(out::Tuple{UnitRange{Int64}, UnitRange{Int64}}, sizes::Tuple{Int64, Int64}, ranges::Tuple{})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:122
  [2] maybeonerange(out::Tuple{UnitRange{Int64}}, sizes::Tuple{Int64, Int64, Int64}, ranges::Tuple{UnitRange{Int64}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:123
  [3] maybeonerange(out::Tuple{}, sizes::NTuple{4, Int64}, ranges::Tuple{UnitRange{Int64}, UnitRange{Int64}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:123
  [4] maybeonerange(sizes::NTuple{4, Int64}, ranges::Tuple{UnitRange{Int64}, UnitRange{Int64}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:126
  [5] subsetarg(x::_DiskArray{Float64, 4, Array{Float64, 4}}, a::Tuple{UnitRange{Int64}, UnitRange{Int64}})
    @ Main ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:144
  [6] (::DiskArrays.var"#62#64"{Tuple{UnitRange{Int64}, UnitRange{Int64}}})(i::_DiskArray{Float64, 4, Array{Float64, 4}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:41
  [7] map
    @ ./tuple.jl:291 [inlined]
  [8] (::DiskArrays.var"#61#63"{Matrix{…}, Base.Broadcast.Broadcasted{…}})(cnow::Tuple{UnitRange{…}, UnitRange{…}})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:41
  [9] foreach(f::DiskArrays.var"#61#63"{Matrix{…}, Base.Broadcast.Broadcasted{…}}, itr::DiskArrays.GridChunks{2, Tuple{…}})
    @ Base ./abstractarray.jl:3094
 [10] copyto!(dest::Matrix{…}, bc::Base.Broadcast.Broadcasted{…})
    @ DiskArrays ~/.julia/packages/DiskArrays/1rcQi/src/broadcast.jl:38
 [11] materialize!
    @ Base.Broadcast ./broadcast.jl:914 [inlined]
 [12] materialize!(dest::Matrix{…}, bc::Base.Broadcast.Broadcasted{…})
    @ Base.Broadcast ./broadcast.jl:911
 [13] top-level scope
    @ REPL[16]:1

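This seems to happen with trailing singleton dimensions because the recursive `maybeonerange` never reaches a matching base case: once `ranges` is exhausted but `sizes` still holds the trailing singleton sizes, it calls `splittuple()` with no arguments. This can be fixed by adding the following methods: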

maybeonerange(out, sizes, ::Tuple{}) = out
maybeonerange(out, ::Tuple{}, ::Tuple{}) = out

The first method handles the case reached with trailing singleton dimensions; the second is needed to resolve a method ambiguity. I think that by the time the call reaches this point the array shapes have already been checked, so it should be safe to do this, but I could be wrong.
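To make the base-case and ambiguity argument concrete, here is a toy recursion with the same overall shape as the one in the stack trace: it walks the source `sizes` and the destination `ranges` and accumulates output ranges in `out`. It is a hypothetical sketch, not the DiskArrays.jl source; the name `toy_ranges` and the rule of indexing singleton source dimensions with `1:1` are assumptions made for illustration. Without base case 2, the generic step would be reached with `ranges == ()` and fail on the empty tuple, much like `splittuple()` does. Without the tie-breaker, a call with both tuples empty would match base cases 1 and 2 equally well, and Julia would report an ambiguity.

```julia
# Hypothetical toy model; NOT the DiskArrays.jl implementation.

# Entry point: start with an empty accumulator.
toy_ranges(sizes::Tuple, ranges::Tuple) = toy_ranges((), sizes, ranges)

# Generic step: pair the next source size with the next destination range.
# A singleton source dimension is indexed with 1:1; otherwise use the range.
function toy_ranges(out::Tuple, sizes::Tuple, ranges::Tuple)
    s, srest = first(sizes), Base.tail(sizes)
    r, rrest = first(ranges), Base.tail(ranges)
    return toy_ranges((out..., s == 1 ? (1:1) : r), srest, rrest)
end

# Base case 1: all source sizes consumed.
toy_ranges(out::Tuple, ::Tuple{}, ranges::Tuple) = out

# Base case 2 (analogue of the proposed fix): ranges exhausted while trailing
# source sizes remain; stop and return what has been collected.
toy_ranges(out::Tuple, sizes::Tuple, ::Tuple{}) = out

# Tie-breaker: both tuples empty would otherwise be ambiguous between 1 and 2.
toy_ranges(out::Tuple, ::Tuple{}, ::Tuple{}) = out
```

With the shapes from the report, `toy_ranges((10, 9, 1, 1), (1:10, 1:9))` stops at base case 2 and returns only the two leading ranges, which is the same "just return `out`" behaviour as the proposed `maybeonerange(out, sizes, ::Tuple{}) = out`.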

This pull request makes that change and adds a test.

rafaqz commented 7 months ago

From the error, it looks like this didn't work on Julia 1.6 in the first place (Julia throws an error before we hit the DiskArrays.jl error).

We may need to run this test only on 1.9/1.10, or whichever version introduced the behavior.

wkearn commented 7 months ago

Looks like this is the commit that added this broadcasting behavior, so the test should work on Julia 1.7 and higher. I've wrapped the testset in a check so it only runs on these versions, but let me know if you would rather do it a different way.
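A version-gated testset of the kind described above might look roughly like the sketch below. The actual test added in the pull request is not shown in this thread; this version uses plain Arrays so it is self-contained, whereas the real test would use a DiskArray source such as `_DiskArray` from the reproduction. The testset name is made up.

```julia
using Test

# Only run on Julia 1.7+, where Base broadcasting accepts a destination with
# fewer dimensions when the source's extra trailing dimensions have size 1.
if VERSION >= v"1.7"
    @testset "broadcast into fewer dims with trailing singletons" begin
        dest = zeros(10, 9)
        src = rand(10, 9, 1, 1)
        dest .= src
        @test dest == src[:, :, 1, 1]
    end
end
```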

CI is failing on nightly, but I think that is unrelated to this change. See #142.

rafaqz commented 7 months ago

Thanks