r-lib / rray

Simple Arrays
https://rray.r-lib.org
GNU General Public License v3.0
130 stars 12 forks source link

`rray_bind()` `axes` generalization #178

Open DavisVaughan opened 5 years ago

DavisVaughan commented 5 years ago

This is definitely a "future" improvement, but rray_bind() could technically bind over multiple axes at once, which would result in a block-diagonal binding, like in this example. This is essentially a generalization of rray_bind(), but I'm not sure how useful it would actually be.

https://stackoverflow.com/questions/17495841/block-diagonal-binding-of-matrices

juangomezduaso commented 5 years ago

In matrix algebra this is known as the "direct sum" operator, and it has its uses. Beyond 2D, I don't know whether it is useful in tensor computing (physics, engineering) or in neural networks, etc. In my mostly OLAP-oriented experience, I have never needed it. The axes-bind generalization would be even more general (it has "MARGIN" dims as well). Neither would be too complicated to implement, but perhaps the direct sum would be easier for users to grasp. Here is a quick example of the direct sum in R (without dimnames management):

library(rray)
library(purrr)
# Bind arrays block-diagonally along every axis (the "direct sum").
#
# Each input's block starts right after the previous input's block along
# every dimension; all cells not covered by an input are set to `padvalue`.
# Dimension names are not handled.
#
# @param ... Arrays/rrays that all have the same number of dimensions.
# @param padvalue Value for every cell not covered by an input.
# @return An rray whose extent along each axis is the sum of the inputs'
#   extents along that axis.
diag_bind <- function(..., padvalue = 0) {
  arglist <- list(...)
  if (length(arglist) == 0L) {
    stop("`diag_bind()` needs at least one array.", call. = FALSE)
  }
  n_args <- length(arglist)
  nd <- vctrs::vec_dims(arglist[[1]])

  # All inputs must have the same number of dimensions.
  stopifnot(all(map_int(arglist, vctrs::vec_dims) == nd))

  # nd x n_args matrix of input extents; cbind() keeps it a matrix even
  # when nd == 1.
  dims_matrix <- do.call(cbind, map(arglist, vctrs::vec_dim))

  # (n_args + 1) x nd matrix of block boundaries. Row i is where input i
  # starts and row i + 1 is where it ends. `matrix(nrow = n_args)` guards
  # against apply() returning a plain vector when there is one input.
  ends <- matrix(apply(dims_matrix, 1, cumsum), nrow = n_args)
  coords <- rbind(0, ends)

  # Index ranges of input i along each axis. seq_len() gives an empty range
  # (not a descending one) for zero-extent dimensions.
  indexes <- function(i) {
    map(seq_len(nd), function(j) {
      coords[i, j] + seq_len(coords[i + 1, j] - coords[i, j])
    })
  }

  # Background filled with `padvalue`, sized by the final block boundaries.
  lastcoord <- coords[nrow(coords), ]
  result <- rray(rep(padvalue, prod(lastcoord)), lastcoord)

  # Overwrite each input's block onto the background.
  for (i in seq_along(arglist)) {
    rray_subset(result, !!!indexes(i)) <- arglist[[i]]
  }
  result
}
# Three 3-D inputs with a different extent along every axis.
r1 <- rray(1:6, dim = c(3, 2, 1))
r2 <- rray(11:18, dim = c(2, 2, 2))
r3 <- rray(21, dim = c(1, 1, 1))
diag_bind(r1, r2, r3)
#> <rray<dbl>[,5,4][120]>
#> , , 1
#> 
#>      [,1] [,2] [,3] [,4] [,5]
#> [1,]    1    4    0    0    0
#> [2,]    2    5    0    0    0
#> [3,]    3    6    0    0    0
#> [4,]    0    0    0    0    0
#> [5,]    0    0    0    0    0
#> [6,]    0    0    0    0    0
#> 
#> , , 2
#> 
#>      [,1] [,2] [,3] [,4] [,5]
#> [1,]    0    0    0    0    0
#> [2,]    0    0    0    0    0
#> [3,]    0    0    0    0    0
#> [4,]    0    0   11   13    0
#> [5,]    0    0   12   14    0
#> [6,]    0    0    0    0    0
#> 
#> , , 3
#> 
#>      [,1] [,2] [,3] [,4] [,5]
#> [1,]    0    0    0    0    0
#> [2,]    0    0    0    0    0
#> [3,]    0    0    0    0    0
#> [4,]    0    0   15   17    0
#> [5,]    0    0   16   18    0
#> [6,]    0    0    0    0    0
#> 
#> , , 4
#> 
#>      [,1] [,2] [,3] [,4] [,5]
#> [1,]    0    0    0    0    0
#> [2,]    0    0    0    0    0
#> [3,]    0    0    0    0    0
#> [4,]    0    0    0    0    0
#> [5,]    0    0    0    0    0
#> [6,]    0    0    0    0   21

Created on 2019-05-20 by the reprex package (v0.2.1)

juangomezduaso commented 5 years ago

A function related to rray_bind() (the axes version) is the xtensor xt::pad function, which I don't know whether you plan to include in rray later on. I still think that bind_axes() would be a complicated function to use, but it allows powerful constructs. For instance, it can implement xt::pad (though you would probably implement them the other way around).

library(rray)
library(purrr)
library(vctrs)
###### utility
# Return `target` with the elements at positions `axes` replaced by `x`.
# `x` supplies one value per axis, so it must be as long as `axes`.
assign_on_axes <- function(x, axes, target) {
  stopifnot(length(axes) == length(x))
  replace(target, axes, x)
}
################# BIND AXES #########
# Bind arrays block-diagonally along the axes listed in `axes`.
#
# Along each axis in `axes`, every input's block starts right after the
# previous input's block. Along the remaining ("margin") axes, all inputs must
# have the same extent and are aligned. Cells not covered by any input get
# `padvalue`. With `axes = NULL` every axis is bound (a full direct sum).
#
# @param ... Arrays/rrays with the same number of dimensions and matching
#   extents on every axis not listed in `axes`.
# @param axes Axes to bind along; `NULL` (default) means all axes.
# @param padvalue Value for every cell not covered by an input.
# @return An rray with summed extents along `axes` and the common margin
#   extents elsewhere.
bind_axes <- function(..., axes = NULL, padvalue = 0) {
  arglist <- list(...)
  if (length(arglist) == 0L) {
    stop("`bind_axes()` needs at least one array.", call. = FALSE)
  }
  nr <- length(arglist)
  first_dim <- vec_dim(arglist[[1]])
  nd <- vec_dims(arglist[[1]])
  # Resolve the default before counting axes; otherwise `na` would be 0.
  if (is.null(axes)) {
    axes <- seq_len(nd)
  }
  na <- length(axes)
  margin_dim <- first_dim[-axes]

  # All inputs need the same rank and identical extents on the margin axes.
  stopifnot(all(map_int(arglist, vec_dims) == nd))
  stopifnot(all(map_lgl(arglist[-1], function(r) {
    all(vec_dim(r)[-axes] == margin_dim)
  })))

  # nd x nr matrix of extents; cbind() keeps it a matrix even when nd == 1.
  dims_matrix <- do.call(cbind, map(arglist, vec_dim))

  # (nr + 1) x na matrix of block boundaries. Row i is where input i starts
  # along each bound axis and row i + 1 is where it ends.
  # `matrix(nrow = nr)` guards against apply() collapsing to a vector when
  # there is a single input.
  sizes <- dims_matrix[axes, , drop = FALSE]
  ends <- matrix(apply(sizes, 1, cumsum), nrow = nr)
  coords <- rbind(0, ends)

  # Subscripts for input i: a range on each bound axis, and TRUE (select
  # everything) on the margin axes. seq_len() gives an empty range for zero
  # extents.
  indexes <- function(i) {
    ranges <- map(seq_len(na), function(j) {
      coords[i, j] + seq_len(coords[i + 1, j] - coords[i, j])
    })
    assign_on_axes(ranges, axes, rep(TRUE, nd))
  }

  # Background: summed extents on the bound axes, margin extents elsewhere.
  lastcoord <- coords[nrow(coords), ]
  out_dim <- assign_on_axes(lastcoord, axes, first_dim)
  result <- rray(rep(padvalue, prod(out_dim)), out_dim)

  # Overwrite each input's block onto the background.
  for (i in seq_along(arglist)) {
    rray_subset(result, !!!indexes(i)) <- arglist[[i]]
  }
  result
}
# The inputs agree on axis 2 (extent 2), so they can be bound along axes 1
# and 3.
r1 <- rray(1:6, dim = c(3, 2, 1))
r2 <- rray(11:18, dim = c(2, 2, 2))
r3 <- rray(21:22, dim = c(1, 2, 1))
bind_axes(r1, r2, r3, axes = c(1, 3))
#> <rray<dbl>[,2,4][48]>
#> , , 1
#> 
#>      [,1] [,2]
#> [1,]    1    4
#> [2,]    2    5
#> [3,]    3    6
#> [4,]    0    0
#> [5,]    0    0
#> [6,]    0    0
#> 
#> , , 2
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]   11   13
#> [5,]   12   14
#> [6,]    0    0
#> 
#> , , 3
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]   15   17
#> [5,]   16   18
#> [6,]    0    0
#> 
#> , , 4
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]    0    0
#> [5,]    0    0
#> [6,]   21   22

##### xt::pad (simple version: pads either before or after, not both at once)
# Pad `x` with `padvalue` along `axes`, entirely before or entirely after the
# data.
#
# @param x Array/rray to pad.
# @param offsets Number of padding cells to add along each axis in `axes`.
# @param axes Axes to pad along (same length as `offsets`).
# @param padvalue Fill value for every cell that does not come from `x`.
# @param before If TRUE, pad before `x`; otherwise pad after it.
# @return An rray whose extents along `axes` are those of `x` plus `offsets`.
xtpad <- function(x, offsets, axes, padvalue = 0, before = TRUE) {
  paddim <- assign_on_axes(offsets, axes, vec_dim(x))
  pad <- rray(padvalue, paddim)
  # Forward `padvalue` so the off-diagonal cells that bind_axes() fills also
  # get it; otherwise they fall back to bind_axes()'s default of 0.
  if (before) {
    bind_axes(pad, x, axes = axes, padvalue = padvalue)
  } else {
    bind_axes(x, pad, axes = axes, padvalue = padvalue)
  }
}
xtpad(r2, offsets = 4:3, axes = c(1, 3))
#> <rray<dbl>[,2,5][60]>
#> , , 1
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]    0    0
#> [5,]    0    0
#> [6,]    0    0
#> 
#> , , 2
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]    0    0
#> [5,]    0    0
#> [6,]    0    0
#> 
#> , , 3
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]    0    0
#> [5,]    0    0
#> [6,]    0    0
#> 
#> , , 4
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]    0    0
#> [5,]   11   13
#> [6,]   12   14
#> 
#> , , 5
#> 
#>      [,1] [,2]
#> [1,]    0    0
#> [2,]    0    0
#> [3,]    0    0
#> [4,]    0    0
#> [5,]   15   17
#> [6,]   16   18
xtpad(r2, offsets = 1, axes = 2, before = FALSE)
#> <rray<dbl>[,3,2][12]>
#> , , 1
#> 
#>      [,1] [,2] [,3]
#> [1,]   11   13    0
#> [2,]   12   14    0
#> 
#> , , 2
#> 
#>      [,1] [,2] [,3]
#> [1,]   15   17    0
#> [2,]   16   18    0

Created on 2019-05-21 by the reprex package (v0.2.1)