greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

speed up greta:::tf_extract() and greta:::tf_replace() #309

Open goldingn opened 5 years ago

goldingn commented 5 years ago

greta's extraction (e.g. a <- x[2]) and replacement (e.g. x[2] <- 0) syntax uses the internal tensorflow functions tf_extract() and tf_replace() to do the operations on tensors, with a shim to map R's extract/replace syntax onto TensorFlow's. This shim doesn't always use the most efficient operations for common extraction methods.

For example: x[2, ] could use tf$slice(), which might be more efficient than the current general approach of reshaping to a vector, using tf$gather(), and then reshaping back to a matrix.
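
To make the contrast concrete, here's a minimal sketch of the two approaches on a plain (non-batched) matrix tensor; tf_extract_row_general() and tf_extract_row_slice() are hypothetical helper names, not greta functions:

library(tensorflow)

# current general pattern: flatten, gather flat (0-based, row-major) indices,
# then reshape back to a matrix
tf_extract_row_general <- function(x, row, n_col) {
  flat_idx <- as.integer((row - 1L) * n_col + seq_len(n_col) - 1L)
  x_flat <- tf$reshape(x, shape(-1))
  tf$reshape(tf$gather(x_flat, flat_idx), shape(1L, n_col))
}

# suggested alternative: a single tf$slice() call, no flattening or reshaping
tf_extract_row_slice <- function(x, row, n_col) {
  tf$slice(x,
           begin = c(as.integer(row - 1L), 0L),
           size = c(1L, as.integer(n_col)))
}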

x[2, ] <- 0 could use tf$tensor_scatter_nd_update(), which would probably be much more efficient than the current (greta:::tf_recombine()) approach of flattening the matrix to a vector, breaking it up into sub-vectors, replacing some of them, and then recombining them with tf$concat() before reshaping back into a matrix 😓.
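
For illustration, a minimal non-batched sketch of that scatter op, mirroring x[2, ] <- 0 (assuming a TensorFlow version that has tensor_scatter_nd_update()):

library(tensorflow)

x <- tf$ones(shape(4L, 3L))
indices <- matrix(1L, nrow = 1L, ncol = 1L)  # 0-based row index for R's x[2, ]
updates <- tf$zeros(shape(1L, 3L))           # the replacement row
x_new <- tf$tensor_scatter_nd_update(x, indices, updates)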

This would particularly help when people write for loops that alter elements of matrices (e.g. for timeseries models), and would reduce the need for nasty hacks like storing the iterated components in lists.
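
For example, this kind of loop (a hypothetical AR(1)-style recursion) would get much faster:

library(greta)

n_times <- 20
rho <- uniform(-1, 1)
innovations <- normal(0, 1, dim = n_times - 1)

x <- zeros(n_times, 1)
for (t in 2:n_times) {
  # each replacement currently triggers the flatten / split / concat dance
  x[t, 1] <- rho * x[t - 1, 1] + innovations[t - 1]
}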

goldingn commented 5 years ago

It would also be awesome if this rewrite could enable indexing with greta arrays.

goldingn commented 4 years ago

Here's a speed comparison with a custom extraction op (still using tf$gather()) for numeric row indices, on a relatively realistic model with a large dataset. It takes around 2/3 of the time with the new operation, which could be wrapped up in [ quite easily.

library(tensorflow)
library(greta)

# greta's internal constructor for new operation nodes
op <- .internals$nodes$constructors$op

# gather rows of a batched (batch, rows, cols) tensor: move the batch
# dimension to the end so tf$gather() indexes rows, then move it back
tf_grab_rows <- function(x, rows) {
  x_t <- tf$transpose(x, c(1L, 2L, 0L))
  z_t <- tf$gather(x_t, rows)
  z <- tf$transpose(z_t, c(2L, 0L, 1L))
  z
}

# greta-level wrapper: build an op node with the new output dimension,
# passing 0-based row indices through to tf_grab_rows()
grab_rows <- function(x, row_vec) {
  dim <- dim(x)
  dim[1] <- length(row_vec)
  op("grab_rows",
     x,
     dim = dim,
     operation_args = list(rows = as.integer(row_vec - 1L)),
     tf_operation = "tf_grab_rows")
}

# make a big (7500 rows) fake dataset with lots of hierarchical levels
iris2 <- do.call(rbind, replicate(50, iris, simplify = FALSE))
X <- model.matrix(~ Sepal.Width * Petal.Length * Petal.Width, data = iris2)
n_levels <- 500
idx <- sample.int(n_levels, nrow(X), replace = TRUE)
y <- iris2$Sepal.Length
k <- ncol(X)

# current way
beta <- normal(0, 1, dim = c(n_levels, k))
sigma <- normal(0, 1, truncation = c(0, Inf))
eta <- rowSums(X * beta[idx, ])
distribution(y) <- normal(eta, sigma)
m <- model(beta, sigma)
system.time(draws <- mcmc(m))

# new way
beta <- normal(0, 1, dim = c(n_levels, k))
sigma <- normal(0, 1, truncation = c(0, Inf))
eta <- rowSums(X * grab_rows(beta, idx))
distribution(y) <- normal(eta, sigma)
m <- model(beta, sigma)
system.time(draws <- mcmc(m))

njtierney commented 2 years ago

Relevant part of greta code:

https://github.com/greta-dev/greta/blob/master/R/extract_replace_combine.R#L81-L258

goldingn commented 1 year ago

Here's a reprex for a particularly slow bit of replacement code. This is a simplified version of a problem that came up in a real model recently. It is sloooow when doing replacement.

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# this function takes forever with moderately large n
dodgy_subsetting <- function(n) {

  valid <- matrix(sample(c(TRUE, FALSE), n^2, replace = TRUE), n, n)
  n_valid <- sum(valid)

  a <- normal(0, 1)
  y_valid <- ones(n_valid) * a
  y <- zeros(n, n)
  y[valid] <- y_valid

  system.time(
    calculate(y, nsim = 1)
  )

}

# OK
dodgy_subsetting(n = 10)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
#>    user  system elapsed 
#>   0.686   0.031   0.717

# slooow
dodgy_subsetting(n = 50)
#>    user  system elapsed 
#>  15.545   0.245  15.928

# this non-replacement version achieves the same result. It should be no
# more computationally complex, and unlike the replacement version it is
# very fast for reasonably large n. We should expect the replacement
# version to be able to reach a similar speed.
mask_version <- function(n) {

  valid <- matrix(sample(c(TRUE, FALSE), n^2, replace = TRUE), n, n)

  valid_mask <- valid
  valid_mask[] <- as.numeric(valid_mask[])

  a <- normal(0, 1)
  y <- ones(n, n) * a
  y <- y * valid_mask

  system.time(
    calculate(y, nsim = 1)
  )

}

# fast
mask_version(10)
#>    user  system elapsed 
#>   0.092   0.002   0.094

# fast
mask_version(50)
#>    user  system elapsed 
#>   0.105   0.002   0.108

# fast
mask_version(1000)
#>    user  system elapsed 
#>   0.163   0.040   0.145

Created on 2022-11-22 by the reprex package (v2.0.0)
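
For what it's worth, a rough sketch of how a replacement like y[valid] <- y_valid could map to a single scatter op on the tensor side (non-batched, and ignoring that R's [ fills elements in column-major order, which a real implementation would need to match):

library(tensorflow)

valid <- matrix(sample(c(TRUE, FALSE), 16, replace = TRUE), 4, 4)
y <- tf$zeros(shape(4L, 4L))
idx <- tf$where(tf$constant(valid))                # [n_valid, 2] coordinates of the TRUEs
updates <- tf$ones(shape(as.integer(sum(valid))))  # stand-in for y_valid
y_new <- tf$tensor_scatter_nd_update(y, idx, updates)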