Open goldingn opened 5 years ago
It would also be awesome if this rewrite could enable indexing with greta arrays.
Here's a speed comparison with a custom extract (still using `tf$gather()`) on numeric rows, for a relatively realistic model on a large dataset. It takes around 2/3 the time with the new operation, which could be wrapped up in `[` quite easily.
library(tensorflow)
library(greta)
op <- .internals$nodes$constructors$op
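# gather rows of a tensor whose first dimension is the batch dimension
tf_grab_rows <- function(x, rows) {
  # move the first dimension to the end, so that rows come first
  x_t <- tf$transpose(x, c(1L, 2L, 0L))
  z_t <- tf$gather(x_t, rows)
  # move the first dimension back to the front
  z <- tf$transpose(z_t, c(2L, 0L, 1L))
  z
}
grab_rows <- function (x, row_vec) {
  dim <- dim(x)
  dim[1] <- length(row_vec)
  op("grab_rows",
     x,
     dim = dim,
     # convert R's one-based row indices to zero-based indices for TensorFlow
     operation_args = list(rows = as.integer(row_vec - 1)),
     tf_operation = "tf_grab_rows")
}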
# make a big (7500 rows) fake dataset with lots of hierarchical levels
iris2 <- do.call(rbind, replicate(50, iris, simplify = FALSE))
X <- model.matrix(~ Sepal.Width * Petal.Length * Petal.Width, data = iris2)
n_levels <- 500
idx <- sample.int(n_levels, nrow(X), replace = TRUE)
y <- iris2$Sepal.Length
k <- ncol(X)
# current way
beta <- normal(0, 1, dim = c(n_levels, k))
sigma <- normal(0, 1, truncation = c(0, Inf))
eta <- rowSums(X * beta[idx, ])
distribution(y) <- normal(eta, sigma)
m <- model(beta, sigma)
system.time(draws <- mcmc(m))
# new way
beta <- normal(0, 1, dim = c(n_levels, k))
sigma <- normal(0, 1, truncation = c(0, Inf))
eta <- rowSums(X * grab_rows(beta, idx))
distribution(y) <- normal(eta, sigma)
m <- model(beta, sigma)
system.time(draws <- mcmc(m))
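To make the transposes in `tf_grab_rows()` easier to follow, here is a minimal base R sketch of the same idea, using `aperm()` in place of `tf$transpose()` and plain indexing in place of `tf$gather()`. It assumes a three-dimensional array whose first dimension stands in for the batch dimension; the helper name `grab_rows_base()` and the toy array are made up for illustration and are not part of greta. Note that `aperm()` permutations are one-based, while the TensorFlow ones are zero-based.

```r
# base R analogue of tf_grab_rows(); hypothetical helper, for illustration only
grab_rows_base <- function(x, rows) {
  # (batch, rows, cols) -> (rows, cols, batch), so rows are the leading dimension
  x_t <- aperm(x, c(2, 3, 1))
  # select rows along the leading dimension (one-based here, unlike tf$gather())
  z_t <- x_t[rows, , , drop = FALSE]
  # (rows, cols, batch) -> (batch, rows, cols)
  aperm(z_t, c(3, 1, 2))
}

# toy array: 2 batches, 4 rows, 3 columns
x <- array(seq_len(2 * 4 * 3), dim = c(2, 4, 3))
rows <- c(3, 1, 3)

z <- grab_rows_base(x, rows)
dim(z)

# compare against indexing the row dimension directly
all(z == x[, rows, , drop = FALSE])
```

Repeated indices (row 3 appears twice above) are handled the same way as in `beta[idx, ]`, which is what the hierarchical model relies on.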
Relevant part of greta code:
https://github.com/greta-dev/greta/blob/master/R/extract_replace_combine.R#L81-L258
Here's a reprex for a particularly slow bit of replacement code. This is a simplified version of a problem that came up in a real model recently. It is sloooow when doing replacement.
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
# this function takes forever with moderately large n
dodgy_subsetting <- function(n) {
  valid <- matrix(sample(c(TRUE, FALSE), n^2, replace = TRUE), n, n)
  n_valid <- sum(valid)
  a <- normal(0, 1)
  y_valid <- ones(n_valid) * a
  y <- zeros(n, n)
  y[valid] <- y_valid
  system.time(
    calculate(y, nsim = 1)
  )
}
# OK
dodgy_subsetting(n = 10)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
#> user system elapsed
#> 0.686 0.031 0.717
# slooow
dodgy_subsetting(n = 50)
#> user system elapsed
#> 15.545 0.245 15.928
# this non-replacement version achieves the same result. It should not be
# computationally any more complicated, and unlike the replacement version it is
# very fast for reasonably large n. We should expect the replacement version to
# have a similar speed.
mask_version <- function(n) {
  valid <- matrix(sample(c(TRUE, FALSE), n^2, replace = TRUE), n, n)
  valid_mask <- valid
  valid_mask[] <- as.numeric(valid_mask[])
  a <- normal(0, 1)
  y <- ones(n, n) * a
  y <- y * valid_mask
  system.time(
    calculate(y, nsim = 1)
  )
}
# fast
mask_version(10)
#> user system elapsed
#> 0.092 0.002 0.094
# fast
mask_version(50)
#> user system elapsed
#> 0.105 0.002 0.108
# fast
mask_version(1000)
#> user system elapsed
#> 0.163 0.040 0.145
Created on 2022-11-22 by the reprex package (v2.0.0)
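The claim that the two versions achieve the same result can be checked without greta. The sketch below is base R only, with a fixed number standing in for the `normal(0, 1)` variable `a`; the names `y_replace` and `y_mask` are illustrative. It says nothing about greta's performance, only that replacement and masking agree on the values.

```r
set.seed(1)
n <- 5
valid <- matrix(sample(c(TRUE, FALSE), n^2, replace = TRUE), n, n)
a <- 0.7  # fixed stand-in for the random variable a

# replacement version: start from zeros and overwrite the valid cells
y_replace <- matrix(0, n, n)
y_replace[valid] <- rep(1, sum(valid)) * a

# mask version: fill every cell, then zero out the invalid ones
y_mask <- matrix(1, n, n) * a * valid

all.equal(y_replace, y_mask)
```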
greta's extraction (e.g. `a <- x[2]`) and replacement (e.g. `x[2] <- 0`) syntax uses the internal tensorflow functions `tf_extract()` and `tf_replace()` to do the operations on tensors, with a shim to map R's extract/replace syntax onto TensorFlow's. This shim doesn't always use the most efficient operations for common extraction methods.

For example:

`x[2, ]` could use `tf$slice()`, which might be more efficient than the current general approach of reshaping to a vector, using `tf$gather()` and then reshaping to a matrix.

`x[2, ] <- 0` could use `tf$tensor_scatter_nd_update()`, which would probably be much more efficient than the current (`greta:::tf_recombine()`) approach of flattening the matrix into a vector, breaking it up into vectors, replacing some, and then recombining them with `tf$concat()` before reshaping the vector into a matrix 😓.

This would particularly help when people write for loops that alter elements in matrices (e.g. for timeseries models), and reduce the need for nasty hacks like storing the iterated components in lists.
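As a rough sketch of what the two suggested operations look like from R, the snippet below calls `tf$slice()` and `tf$tensor_scatter_nd_update()` directly on a small tensor. It assumes the tensorflow R package with TensorFlow 2.x running eagerly, and it ignores the batch dimension that greta's tensors carry, so it is not how greta would have to wire this up internally; the variable names are illustrative.

```r
library(tensorflow)

# a 4 x 3 tensor of doubles; indices below are zero-based, as TensorFlow expects
x <- tf$constant(matrix(as.numeric(1:12), nrow = 4, ncol = 3))

# x[2, ]: start at row 1, column 0; take 1 row and all remaining columns (-1)
row2 <- tf$slice(x, begin = c(1L, 0L), size = c(1L, -1L))

# x[2, ] <- 0: overwrite the row at index 1 with a row of zeros
updated <- tf$tensor_scatter_nd_update(
  x,
  indices = list(list(1L)),
  updates = tf$zeros(shape(1L, 3L), dtype = x$dtype)
)

as.array(row2)
as.array(updated)
```

Both calls work on the matrix in its original shape, so neither needs the flatten, split, concatenate and reshape steps described above.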