kadyb / OGH2022

OpenGeoHub 2022 Workshop: "Benchmarking R and Python for spatial data processing"
https://kadyb.github.io/OGH2022/

Spatial prediction speed #2

Status: Open. kadyb opened this issue 2 years ago.

kadyb commented 2 years ago
library("stars")
library("terra")
library("mclust")

## Predict a class for every row of `df` with the trained model `mdl`.
## Rows with any NA get NA, so the returned vector has one value per
## pixel (length nrow(df)).
df_predict = function(df, model = mdl) {
  ok = complete.cases(df)  # logical mask: TRUE for rows without any NA
  full_pred = rep(NA_integer_, length.out = nrow(df))
  full_pred[ok] = predict(model, df[ok, ])$classification
  return(full_pred)
}

f = "raster.tif"
tmp = tempfile(fileext = ".tif")
bands = paste0("B", c(1:7, 9:11))

##### stars #####
ras = read_stars(f, proxy = FALSE)
df = as.data.frame(split(ras))
colnames(df)[3:12] = bands

# Sample up to 500,000 pixels, drop rows with NA and train a
# Gaussian mixture model (GMM) with 5 components
smp = df[sample(nrow(df), min(nrow(df), 500000)), ]
smp = na.omit(smp)
mdl = Mclust(smp, 5)

full_pred = df_predict(df)
ras_pred = cbind(df[, 1:2], full_pred)
ras_pred = st_as_stars(ras_pred, dims = c("x", "y"), crs = st_crs(ras))
plot(ras_pred)
write_stars(ras_pred, tmp, type = "Byte", options = "COMPRESS=LZW")

##### terra #####
ras = rast(f)
df = as.data.frame(ras, xy = TRUE, na.rm = FALSE)
colnames(df)[3:12] = bands
full_pred = df_predict(df)
ras_pred = cbind(df[, 1:2], full_pred)
ras_pred = rast(ras_pred, crs = crs(ras))
plot(ras_pred)
writeRaster(ras_pred, tmp, overwrite = TRUE, datatype = "INT1U", gdal = "COMPRESS=LZW")
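A general R pitfall when dropping NA rows by negative indexing (for example `df[-na_ids, ]`): if no row contains NA, `na_ids` is `integer(0)`, and negative indexing with an empty vector selects zero rows instead of all of them, so the prediction would run on an empty table. A logical mask from `complete.cases()` does not have this problem. A minimal base-R check on a toy data frame (the values are hypothetical):

```r
# Toy pixel table with no missing values (hypothetical data)
df = data.frame(B1 = c(1, 2, 3), B2 = c(4, 5, 6))

# Row indices of NA cells, as in the negative-indexing approach
na_ids = unique(which(is.na(df), arr.ind = TRUE)[, 1])
length(na_ids)        # 0: no row has an NA
nrow(df[-na_ids, ])   # 0: -integer(0) selects no rows at all

# Logical mask: TRUE for rows without any NA
ok = complete.cases(df)
nrow(df[ok, ])        # 3: all rows are kept
```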
kadyb commented 2 years ago

Profiling results for the code above (spatial operations are marked with green dots). Most of the run time is spent training the model and predicting with it, rather than in the spatial operations.

[Image: profiling results]
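To check where the time goes without a profiler, one option is to time each stage separately with base R's `system.time()`. The sketch below does this on a random toy table; `stats::kmeans` is swapped in as a stand-in for `Mclust` so it runs without mclust, and the table size, grid shape and stage names are hypothetical. Absolute timings will differ from the real workflow.

```r
set.seed(42)

# Toy "pixel table": 200,000 rows x 10 bands of random values (hypothetical)
n = 200000
X = matrix(runif(n * 10), ncol = 10)

timings = numeric(0)

# Stage 1: sample and train (kmeans as a stand-in for Mclust)
timings["train"] = system.time({
  smp = X[sample(n, 20000), ]
  km = kmeans(smp, centers = 5, iter.max = 50)
})[["elapsed"]]

# Stage 2: predict, i.e. assign every row to its nearest cluster centre
timings["predict"] = system.time({
  d = sapply(seq_len(nrow(km$centers)), function(k) {
    rowSums(sweep(X, 2, km$centers[k, ])^2)  # squared distance to centre k
  })
  cl = max.col(-d, ties.method = "first")    # index of the closest centre
})[["elapsed"]]

# Stage 3: put the predictions back on a 400 x 500 grid
# (a rough stand-in for the raster reconstruction step)
timings["grid"] = system.time({
  grid = matrix(cl, nrow = 400, ncol = 500, byrow = TRUE)
})[["elapsed"]]

print(round(timings, 3))
```

Timing stages this way only gives wall-clock totals per stage; a sampling profiler such as `Rprof()` is still needed to see which functions dominate inside a stage.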