NorskRegnesentral / shapr

Explaining the output of machine learning models with more accurately estimated Shapley values
https://norskregnesentral.github.io/shapr/

Restructured torch modules to support shapr installation without torch #393

Closed: LHBO closed this pull request 2 months ago

LHBO commented 2 months ago

In this PR, we refactor the vaeac approach so that the torch modules are created inside functions. Previously, the modules were defined at the top level of the package code, which R evaluates when the package is installed. Installing shapr therefore failed on systems without torch, because the torch:: calls could not be evaluated. We tried to fix this in #390, but due to technical issues on my end, I had to open a new PR.
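To illustrate the point, here is a minimal sketch in plain R that needs no torch. It assumes only that a package named `notinstalledpkg` (a hypothetical name) is not installed. Defining a function whose body refers to that package is harmless, whereas evaluating the same `pkg::` call directly fails straight away. The same thing happens when `torch::dataset(...)` sits at the top level of a package's `R/` files, since that code runs while the package is being installed.

```r
# Defining a function does not evaluate its body, so referring to a package
# that is not installed ("notinstalledpkg" is a hypothetical name) is harmless.
make_module <- function() {
  notinstalledpkg::some_constructor()
}

# Evaluating the same call directly fails immediately, because `::` tries to
# load the namespace of the missing package.
top_level_result <- tryCatch(
  notinstalledpkg::some_constructor(),
  error = function(e) "failed immediately"
)

# With the function wrapper, the error only surfaces once the function is called.
call_result <- tryCatch(
  make_module(),
  error = function(e) "failed when called"
)
```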

We also fixed some typos in the Roxygen documentation and ensured that the progressr progress bar inside the vaeac approach is only used when progressr is available.
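One way to express such a guard is sketched below. The helper `make_progress_fn()` is hypothetical and not part of shapr's API. It returns a progressor from `progressr::progressor()` when progressr is installed, and otherwise a no-op function, so the training loop can call `p()` unconditionally. The `use_progressr` argument is an assumption added here to make the fallback easy to exercise.

```r
# Hypothetical helper: return a progressr progressor when the package is
# available, otherwise a no-op function with the same calling convention.
make_progress_fn <- function(steps,
                             use_progressr = requireNamespace("progressr", quietly = TRUE)) {
  if (use_progressr) {
    # Only reached when progressr is installed
    progressr::progressor(steps = steps)
  } else {
    # Silent fallback: accepts any arguments and does nothing
    function(...) invisible(NULL)
  }
}

# Example use inside a training loop (not run here):
# p <- make_progress_fn(steps = n_epochs)
# for (epoch in seq_len(n_epochs)) { ...; p() }
```

The default for `use_progressr` is evaluated lazily, so `requireNamespace()` is only called when the caller does not supply the argument.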

Details: Instead of having

```r
vaeac_dataset <- torch::dataset(
  name = "vaeac_dataset",
  initialize = function(X, one_hot_max_sizes) {...},
  .getbatch = function(index) {...},
  .length = function() {...}
)
```

we replaced it with

```r
# The dataset generator is only created when vaeac_dataset() is called,
# so torch is not evaluated when shapr is installed.
vaeac_dataset <- function(X, one_hot_max_sizes) {
  vaeac_dataset_tmp <- torch::dataset(
    name = "vaeac_dataset",
    initialize = function(X, one_hot_max_sizes) {...},
    .getbatch = function(index) {...},
    .length = function() {...}
  )
  # Instantiate the dataset with the supplied data
  return(vaeac_dataset_tmp(X = X, one_hot_max_sizes = one_hot_max_sizes))
}
```
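With the function wrapper, torch is only needed once `vaeac_dataset()` is actually called. A natural companion, sketched below as an assumption rather than something this PR states it adds, is a check at the start of such functions that fails with an informative message when a suggested package is missing. The helper name `check_suggested_pkg()` is hypothetical.

```r
# Hypothetical helper: stop with a clear message if a suggested package
# needed by a given approach is not installed.
check_suggested_pkg <- function(pkg, approach) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    stop(
      sprintf("The '%s' approach requires the '%s' package. Please install it.",
              approach, pkg),
      call. = FALSE
    )
  }
  invisible(TRUE)
}
```

For example, `check_suggested_pkg("torch", "vaeac")` could be the first line of `vaeac_dataset()`, so users without torch get a readable error instead of a namespace-loading failure.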

The memory_layer module needed extra care, because all instances of memory_layer shared a single internal environment. In the new version, we first create this environment and then pass it to the new version of memory_layer.
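The shared-environment pattern can be sketched in plain R without torch. The `memory_layer()` below is a hypothetical stand-in, not shapr's actual implementation. A layer created with `output = FALSE` stores its input under an `id` in a caller-supplied environment, and a layer with `output = TRUE` retrieves it and appends it to its own input (here by concatenating vectors, in place of concatenating tensors). Because R environments have reference semantics, every layer given the same `shared_env` sees the same storage.

```r
# Hypothetical stand-in for memory_layer(): the storage environment is passed
# in by the caller instead of living inside the layer definition.
memory_layer <- function(id, shared_env, output = FALSE) {
  function(x) {
    if (output) {
      # Retrieve what the matching "store" layer saved and append it to x
      stored <- get(id, envir = shared_env, inherits = FALSE)
      c(x, stored)
    } else {
      # Save x for later use and pass it through unchanged
      assign(id, x, envir = shared_env)
      x
    }
  }
}

# Create the environment first, then hand it to every layer that should share it.
shared_env <- new.env()
store_layer <- memory_layer("skip_1", shared_env)
fetch_layer <- memory_layer("skip_1", shared_env, output = TRUE)

y <- store_layer(c(1, 2)) # stores c(1, 2) and returns it unchanged
z <- fetch_layer(3)       # returns c(3, 1, 2)
```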