Hmmm, looking at the blame for the relevant section:
https://github.com/greta-dev/greta/blame/master/R/operators.R#L157
It looks like that code hasn't changed for the past 3 years?
I can confirm that I get the same error (on branch #463). I've also included the relevant checking code at the bottom of the reprex; does it return what you would expect?
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
x <- matrix(1, 2, 3)
y <- rep(1, 3)
# these three work
x %*% y
#> [,1]
#> [1,] 3
#> [2,] 3
x %*% as_data(y)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> Loaded Tensorflow version 1.14.0
#> ✓ Initialising python and checking dependencies ... done!
#>
#> greta array (operation)
#>
#> [,1]
#> [1,] ?
#> [2,] ?
as_data(x) %*% as_data(y)
#> greta array (operation)
#>
#> [,1]
#> [1,] ?
#> [2,] ?
# this errors
as_data(x) %*% y
#> Error: only two-dimensional <greta_array>s can be matrix-multiplied
#> dimensions recorded were 2 and 3
dim(as_data(x))
#> [1] 2 3
length(dim(as_data(x)))
#> [1] 2
dim(y)
#> NULL
length(dim(y))
#> [1] 0
# code that does the checking is:
# if (length(dim(x)) != 2 | length(dim(y)) != 2) {
Created on 2021-11-25 by the reprex package (v2.0.1)
`%*%` is written as:
`%*%` <- function(x, y) { # nolint
# if y is a greta array, coerce x before dispatch
if (inherits(y, "greta_array") & !inherits(x, "greta_array")) {
as_data(x) %*% y
} else {
UseMethod("%*%", x)
}
}
Which gives us
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
x <- matrix(1, 2, 3)
y <- rep(1, 3)
# these three work
x %*% y
#> [,1]
#> [1,] 3
#> [2,] 3
res_1 <- x %*% as_data(y)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
res_2 <- as_data(x) %*% as_data(y)
res_3 <- as_data(x) %*% y
#> Error: only two-dimensional <greta_array>s can be matrix-multiplied
#> dimensions recorded were 2 and 3
res_1
#> greta array (operation)
#>
#> [,1]
#> [1,] ?
#> [2,] ?
calculate(res_1, nsim = 1)
#> $res_1
#> , , 1
#>
#> [,1] [,2]
#> [1,] 3 3
res_2
#> greta array (operation)
#>
#> [,1]
#> [1,] ?
#> [2,] ?
calculate(res_2, nsim = 1)
#> $res_2
#> , , 1
#>
#> [,1] [,2]
#> [1,] 3 3
res_3
#> Error in eval(expr, envir, enclos): object 'res_3' not found
calculate(res_3, nsim = 1)
#> Error in eval(expr, envir, enclos): object 'res_3' not found
Created on 2024-05-13 with reprex v2.1.0
However, I think this won't coerce y to a greta array, and that won't happen on dispatch either, since `%*%.greta_array` is only dispatched on the first argument, x. The existing branch handles the case where y is a greta array but x is not; we also want to handle the case where x is a greta array but y is not, I think?
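To illustrate the point about dispatch, here is a minimal base R sketch. It uses a toy generic `mm()` and a made-up class `"toy"` (both hypothetical, not greta code) to show that `UseMethod()` only looks at the class of the first argument:

```r
# Toy S3 generic; named mm() so that base `%*%` is not masked.
# mm(), mm.toy() and the class "toy" are hypothetical, for illustration only.
mm <- function(x, y) UseMethod("mm", x)
mm.default <- function(x, y) "default method"
mm.toy <- function(x, y) "toy method"

toy <- structure(list(), class = "toy")

# x is a "toy", so mm.toy() is chosen
mm(toy, 1)

# x is a plain numeric, so mm.default() is chosen even though y is a "toy"
mm(1, toy)
```

Whatever class y has is invisible to the dispatch step, which is why any coercion of y has to be done explicitly, either in the generic wrapper or inside the method.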
So I think it needs to be:
#' @rdname overloaded
#' @export
`%*%` <- function(x, y) { # nolint
# if y is a greta array, coerce x before dispatch
if (inherits(y, "greta_array") & !inherits(x, "greta_array")) {
as_data(x) %*% y
# if y is not a greta array and x is, coerce y before dispatch
} else if (!inherits(y, "greta_array") & inherits(x, "greta_array")) {
x %*% as_data(y)
} else {
UseMethod("%*%", x)
}
}
Doing this we get:
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
x <- matrix(1, 2, 3)
y <- rep(1, 3)
# these three work
x %*% y
#> [,1]
#> [1,] 3
#> [2,] 3
res_1 <- x %*% as_data(y)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
res_2 <- as_data(x) %*% as_data(y)
res_3 <- as_data(x) %*% y
res_1
#> greta array (operation)
#>
#> [,1]
#> [1,] ?
#> [2,] ?
calculate(res_1, nsim = 1)
#> $res_1
#> , , 1
#>
#> [,1] [,2]
#> [1,] 3 3
res_2
#> greta array (operation)
#>
#> [,1]
#> [1,] ?
#> [2,] ?
calculate(res_2, nsim = 1)
#> $res_2
#> , , 1
#>
#> [,1] [,2]
#> [1,] 3 3
res_3
#> greta array (operation)
#>
#> [,1]
#> [1,] ?
#> [2,] ?
calculate(res_3, nsim = 1)
#> $res_3
#> , , 1
#>
#> [,1] [,2]
#> [1,] 3 3
Created on 2024-05-13 with reprex v2.1.0
To check this works properly, I have added a test like so:
test_that("%*% works when one is a non-greta array", {
x <- matrix(1, 2, 3)
y <- rep(1, 3)
expect_snapshot(x %*% y)
expect_snapshot(x %*% as_data(y))
expect_snapshot(as_data(x) %*% y)
expect_snapshot(as_data(x) %*% as_data(y))
res_1 <- x %*% as_data(y)
res_2 <- as_data(x) %*% y
res_3 <- as_data(x) %*% as_data(y)
expect_snapshot(calculate(res_1, nsim = 1))
expect_snapshot(calculate(res_2, nsim = 1))
expect_snapshot(calculate(res_3, nsim = 1))
})
This should first capture that these calls don't error and return the appropriate greta array (except for the first case, `x %*% y`, which involves no greta arrays), and then check that the calculated results are correct.
Let me know if this sounds right, @goldingn
When performing a matrix multiply between a greta array matrix and an R vector, I would expect (and think it was the previous behaviour) that the check on dimensions would happen only after attempting to coerce to greta arrays.
But it appears to do the dimension check on the vector before coercing it to a greta array, so matrix-multiplying a greta array matrix by an R vector (of the correct dimensions) errors incorrectly, and the vector needs to be explicitly coerced to a greta array by the user.
I'm guessing that this might happen in other places too, if there was a fundamental change to how this checking is done
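The ordering problem can be sketched in base R alone. In the snippet below, `is_2d()` is a hypothetical helper that mirrors the quoted check `length(dim(x)) != 2`, and `as.matrix()` stands in for the coercion step; that `as_data()` turns a length-3 vector into a 3 x 1 array is an assumption, though it is consistent with `x %*% as_data(y)` working in the reprex:

```r
# Hypothetical helper mirroring the dimension check quoted in the reprex
is_2d <- function(a) length(dim(a)) == 2

y <- rep(1, 3)

# Checking before coercion: a bare vector has dim(y) == NULL, so this is FALSE
is_2d(y)

# Checking after coercion: as.matrix() gives a 3 x 1 column matrix, so this is TRUE
is_2d(as.matrix(y))
dim(as.matrix(y))
```

So the same vector passes or fails the two-dimensional check depending only on whether the check runs before or after coercion.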