JuliaLang / julia

The Julia Programming Language
https://julialang.org/
MIT License

Feature Proposal: specialize `cat` on the last dim #50405

Status: Open. ctarn opened this issue 1 year ago.

ctarn commented 1 year ago

Hi, I would like to open this issue to see whether there is interest in specializing cat on the last dimension.

Currently, cat is specialized on the 1st and 2nd dimensions as vcat and hcat, respectively, and reduce over vcat and hcat has dedicated methods for better performance:

 [4] reduce(::typeof(vcat), A::AbstractVector{<:AbstractVecOrMat})
     @ abstractarray.jl:1703
...
 [6] reduce(::typeof(hcat), A::AbstractVector{<:AbstractVecOrMat})
     @ abstractarray.jl:1706
...
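As a quick illustration (a sketch, not from the original post; the matrix sizes are arbitrary), reduce(vcat, ...) and reduce(hcat, ...) give the same result as splatting, but dispatch to the dedicated methods listed above, which allocate the output once instead of concatenating pairwise:

```julia
# A collection of matrices that all have 3 columns (sizes are arbitrary).
mats = [rand(2, 3) for _ in 1:4]

# Splatting passes every matrix as a separate argument.
splatted = vcat(mats...)

# reduce(vcat, ...) hits the specialized method, which allocates
# the result once instead of folding vcat pairwise.
reduced = reduce(vcat, mats)

@assert splatted == reduced
@assert size(reduced) == (8, 3)

# The same applies to hcat along the 2nd dimension.
@assert size(reduce(hcat, mats)) == (2, 12)
```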

We also have stack for a similar purpose, but stack requires all items to have the same size. (https://github.com/JuliaLang/julia/issues/21672, https://github.com/JuliaLang/julia/pull/43334)
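A small sketch of that restriction (it assumes Julia 1.9 or later, where stack is available; the sizes are arbitrary): stack accepts equal-sized arrays, rejects arrays whose sizes differ, while cat along the existing last dimension accepts them.

```julia
# stack places each array along a new trailing dimension.
same = stack([rand(2, 3), rand(2, 3)])
@assert size(same) == (2, 3, 2)

# Arrays of different sizes are rejected by stack...
stack_failed = try
    stack([rand(2, 3), rand(2, 4)])
    false
catch
    true
end
@assert stack_failed

# ...while cat along the existing last dimension accepts them.
@assert size(cat(rand(2, 3), rand(2, 4); dims=2)) == (2, 7)
```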

In my own code base, I call it scat (stack-like cat) and define it simply as:

"""
    scat(X...)

Concatenate the arrays `X` along the last dimension of the first array.

example:
julia> scat(rand(2, 3, 4, 5), rand(2, 3, 4, 6))
2×3×4×11 Array{Float64, 4}:
[:, :, 1, 1] =
 0.134323  0.369241  0.178764
 0.713772  0.323157  0.903367

[:, :, 2, 1] =
 0.646836   0.142949  0.234398
 0.0158574  0.874712  0.65062

...

"""
scat(X...) = cat(X...; dims=ndims(first(X)))

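A quick check of what this definition does for lower-dimensional inputs (a sketch reusing the scat one-liner; the inputs are arbitrary): for vectors the last dimension is the 1st, so scat matches vcat, and for matrices it is the 2nd, so scat matches hcat.

```julia
scat(X...) = cat(X...; dims=ndims(first(X)))

# Vectors: the last dimension is the 1st, so scat matches vcat.
@assert scat([1, 2], [3, 4, 5]) == vcat([1, 2], [3, 4, 5])

# Matrices: the last dimension is the 2nd, so scat matches hcat.
A, B = rand(2, 3), rand(2, 1)
@assert scat(A, B) == hcat(A, B)

# Higher-dimensional arrays, as in the example above.
@assert size(scat(rand(2, 3, 4, 5), rand(2, 3, 4, 6))) == (2, 3, 4, 11)
```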
Ideally, there would also be a dedicated reduce method for better performance:

reduce(::typeof(scat), A) = ...
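One possible shape for such a method, as a hypothetical sketch rather than a proposed Base implementation: allocate the result once and copy each array into its slice along the last dimension, in the same spirit as the specialized reduce(vcat, ...) and reduce(hcat, ...) methods. It assumes all arrays have the same number of dimensions and matching leading sizes, which is stricter than cat itself.

```julia
# scat as defined in the issue.
scat(X...) = cat(X...; dims=ndims(first(X)))

# Hypothetical fast path for reduce(scat, A): allocate the output once,
# then fill it slice by slice along the last dimension.
function Base.reduce(::typeof(scat), A::AbstractVector{<:AbstractArray})
    isempty(A) && throw(ArgumentError("reduce(scat, A) needs a non-empty collection"))
    x1 = first(A)
    N = ndims(x1)
    lead = size(x1)[1:N-1]  # sizes of all dimensions except the last
    for x in A
        ndims(x) == N || throw(DimensionMismatch("all arrays must have $N dimensions"))
        size(x)[1:N-1] == lead || throw(DimensionMismatch("leading dimensions must match"))
    end
    T = mapreduce(eltype, promote_type, A)   # common element type
    total = sum(x -> size(x, N), A)          # length of the last dimension
    out = similar(x1, T, (lead..., total))
    offset = 0
    for x in A
        n = size(x, N)
        # selectdim returns a view of `out` covering this array's slice
        selectdim(out, N, offset+1:offset+n) .= x
        offset += n
    end
    return out
end

xs = [rand(2, 3, 4), rand(2, 3, 1), rand(2, 3, 5)]
@assert reduce(scat, xs) == scat(xs...)
```

Dispatching on typeof(scat) mirrors how the vcat/hcat fast paths are selected; a real implementation would probably also need to handle mixed dimensionalities the way cat does.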
Seelengrab commented 1 year ago

This may need a bikeshed on the name...

mcabbott commented 1 year ago

Note also that "stack-like" seems misleading: this uses dims=ndims(x), whereas stack would use dims=ndims(x)+1. It concatenates along the existing last dimension rather than always along a new one.

julia> x4 = rand(3, 4, 5, 6);

julia> scat(x4, x4) |> size
(3, 4, 5, 12)

julia> stack((x4, x4)) |> size
(3, 4, 5, 6, 2)
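To make the difference concrete (a sketch assuming Julia 1.9+ for stack and the scat one-liner from the original post): stack matches cat with dims=ndims(x)+1, scat matches cat with dims=ndims(x), and for equal-sized inputs the two layouts differ only by a reshape that merges the last two dimensions.

```julia
scat(X...) = cat(X...; dims=ndims(first(X)))

x4 = rand(3, 4, 5, 6)

# stack adds a new trailing dimension: dims = ndims(x4) + 1.
@assert stack((x4, x4)) == cat(x4, x4; dims=ndims(x4) + 1)

# scat extends the existing last dimension: dims = ndims(x4).
@assert scat(x4, x4) == cat(x4, x4; dims=ndims(x4))

# For equal-sized inputs, merging the last two dimensions of the
# stacked result (column-major reshape) recovers the scat layout.
@assert reshape(stack((x4, x4)), 3, 4, 5, 12) == scat(x4, x4)
```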
ctarn commented 1 year ago

Some candidate names:

  1. scat: stack-like cat
  2. tcat: cat on the trunk dimension
  3. lcat: cat on leading/last dimension
  4. mcat: cat on major/main dimension, or merging cat
  5. bcat: cat as batches
  6. ccat: cat as chunks
  7. pcat: cat on the primary dimension
  8. ocat: cat on the outermost dimension

Not recommended:

  1. dcat: the third dimension is commonly referred to as depth
  2. ncat: could be confused with the existing hvncat
  3. vcat
  4. hcat