yihui / knitr

A general-purpose tool for dynamic report generation in R
https://yihui.org/knitr/

Problem caching instances of torch modules and datasets #2339

Open gavril0 opened 2 months ago

gavril0 commented 2 months ago

Caching a chunk that creates an instance of a torch module or of a torch dataset yields an "external pointer is not valid" error when the instance is used in another chunk.

Example with torch module:

    ```{r, cache=TRUE}
    library(torch)
    lin <- nn_linear(2, 3)
    # torch_save(lin, "lin.pt")
    ```

    ```{r}
    # lin <- torch_load("lin.pt")
    x <- torch_randn(2)
    lin$forward(x)
    ```

Example with torch dataset:

    ```{r, cache=TRUE}
    ds_gen <- dataset(
      initialize = function() {
        self$x <- torch_tensor(1:10, dtype = torch_long())
      },
      .getitem = function(index) {
        self$x[index]
      },
      .length = function() {
        length(self$x)
      }
    )

    ds <- ds_gen()
    ```

    ```{r}
    ds[1:3]
    ```

If there is no cache, the chunks execute without problems. However, once a cache exists, an error is raised when the cached instance of the module or of the dataset is accessed:

    Error in cpp_tensor_dim(x$ptr) : external pointer is not valid

This might be due to the fact that the R torch package relies on reference classes (R6 and/or R7) and could be related to issue #2176. In any case, caching would be useful for trained instances of a module or for datasets whose initialization involves a lot of processing.
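
For what it is worth, the error seems reproducible outside knitr with R's native serialization, which is, as far as I understand, what knitr's cache relies on. A minimal sketch, assuming only the torch package:

    ```{r}
    library(torch)

    # Tensors (and modules built from them) hold external pointers to memory
    # managed by libtorch; R's native serialization does not preserve the
    # pointer address, so the restored object is unusable.
    x <- torch_randn(2)
    saveRDS(x, "x.rds")
    y <- readRDS("x.rds")
    # Touching `y` then typically fails with the same message as above:
    # Error in cpp_tensor_dim(x$ptr) : external pointer is not valid
    ```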

At the moment, the only workaround is to save the torch model in the cached chunk with torch_save and load it in the uncached chunk with torch_load (see the commented lines in the chunks above). However, as far as I know, there is no equivalent way to save and load torch datasets.
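
Spelled out, the workaround for the module example looks like this (a sketch using the lin.pt file from the commented lines above):

    ```{r, cache=TRUE}
    lin <- nn_linear(2, 3)
    # Persist the module explicitly; torch_save() serializes the underlying
    # tensors in a format that torch_load() can restore later.
    torch_save(lin, "lin.pt")
    ```

    ```{r}
    # Reload the module instead of relying on knitr's cached copy.
    lin <- torch_load("lin.pt")
    x <- torch_randn(2)
    lin$forward(x)
    ```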