mlverse / torch

R Interface to Torch
https://torch.mlverse.org

fix(cloning): metadata, finalizer, and repeated cloning #1134

Closed sebffischer closed 4 months ago

sebffischer commented 4 months ago

Addresses https://github.com/mlverse/torch/issues/1126, where you offered to take care of the renaming. That is, the $clone2() method should just be renamed to $clone(). Calling $clone() currently has no effect, because the find_method() function (IIRC) still resolves to the auto-generated torch_clone() function.

sebffischer commented 4 months ago

Also, the C++ code that clones tensors / buffers / parameters behaves differently, right?

sebffischer commented 4 months ago

I also think it would be nice if torch_clone(tensor) and tensor$clone() behaved identically.

sebffischer commented 4 months ago

There is also another discrepancy between the original object and its clone:

library(torch)
lin = nn_linear(1, 1)
lin$parameters$weight$requires_grad
#> [1] TRUE
lin$clone(deep = TRUE)$parameters$weight$requires_grad
#> [1] FALSE

Created on 2024-02-08 with reprex v2.0.2

sebffischer commented 4 months ago

One more:

library(torch)
nn_linear(1, 1)$train()$clone(deep = TRUE)
#> Error in FUN(X[[i]], ...): not an environment
nn_linear(1, 1)$eval()$clone(deep = TRUE)
#> Error in FUN(X[[i]], ...): not an environment

Created on 2024-02-08 with reprex v2.0.2

sebffischer commented 4 months ago

Also, cloning repeatedly causes issues and builds up a chain of parent environments. This is because the clone method retrieved here https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L521 on the second clone is already the patched version, not the original R6 clone implementation. The patched version still has the original version in its enclosing environment, so it still works, but it repeatedly calls the clone() method of that enclosing environment (https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L549) until it reaches the top-level clone, i.e. R6's own clone method.

library(torch)

n = nn_linear(1, 1)
head(attr(n, "module")$clone, n = 1)
#>                                                       
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n, "module")$clone |> environment() |> with(clone), n = 1)
#>                           
#> 1 function (deep = FALSE)

n1 = n$clone(deep = TRUE)
head(attr(n1, "module")$clone, n = 1)
#>                                                       
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n1, "module")$clone |> environment() |> with(clone), n = 1)
#>                                                       
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n1, "module")$clone |> environment() |> with(clone) |> environment() |> with(clone), n = 1)
#>                           
#> 1 function (deep = FALSE)

identical(
  attr(n1, "module")$clone |> environment(),
  attr(n1, "module")$clone |> environment() |> with(clone) |> environment()
)
#> [1] FALSE

attr(n1, "module")$clone |> environment() |> names()
#> [1] "clone"    "f"        "instance"
attr(n1, "module")$clone |> environment() |> with(clone) |> environment() |> names()
#> [1] "clone"    "f"        "instance"

Created on 2024-02-10 with reprex v2.0.2
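
One way to make the patching idempotent would be to tag the patched clone so a second pass does not wrap an already-wrapped clone. A minimal sketch, not torch's actual code; the `torch_patched` attribute and the wrapper body are hypothetical:

```r
# Hypothetical sketch: tag the patched clone so patching is idempotent
# and a second clone does not wrap an already-wrapped clone.
patch_clone <- function(module) {
  if (isTRUE(attr(module$clone, "torch_patched"))) {
    return(invisible(module))  # already patched: do nothing
  }
  original_clone <- module$clone  # the plain R6 clone(deep = FALSE) method
  patched <- function(deep = FALSE, ..., replace_values = TRUE) {
    cloned <- original_clone(deep = deep)
    # ... post-processing of parameters / buffers would happen here ...
    cloned
  }
  attr(patched, "torch_patched") <- TRUE
  unlockBinding("clone", module)  # R6 objects are environments
  module$clone <- patched
  lockBinding("clone", module)
  invisible(module)
}
```

With a guard like this, every clone carries a wrapper around the original R6 clone directly, instead of a wrapper around a wrapper.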

sebffischer commented 4 months ago

Also, I included support for a private clone finalizer method. In mlr3torch we need something like this: we have an nn_module that holds an R6 class containing modules, but these modules also need to be registered in the nn_module itself. The reference identity of these objects therefore needs to be preserved when cloning, and the only solution I came up with is to allow a hook that runs after clone() is called. Let me know what you think :)
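
The hook idea can be illustrated with plain R6, outside of torch. A minimal sketch, assuming a private method named `deep_clone_finalize` (the name is hypothetical) that restores reference identity after R6's clone() has run:

```r
library(R6)

Registry <- R6Class("Registry",
  public = list(
    shared = NULL,
    initialize = function() self$shared <- new.env(),
    clone_deep = function() {
      cloned <- self$clone(deep = TRUE)
      # post-clone hook: let the clone fix up fields that must stay
      # shared by reference with the original
      cloned$.__enclos_env__$private$deep_clone_finalize(self)
      cloned
    }
  ),
  private = list(
    deep_clone_finalize = function(original) {
      # keep the very same environment object (reference identity)
      self$shared <- original$shared
    }
  )
)

r1 <- Registry$new()
r2 <- r1$clone_deep()
identical(r1$shared, r2$shared)  # TRUE: reference identity preserved
```

In torch the finalizer would run inside the patched clone method rather than in a user-facing `clone_deep()` wrapper.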

sebffischer commented 4 months ago

I also just saw: https://github.com/r-lib/R6/pull/273, which would make the post_clone hook officially supported by R6.

sebffischer commented 4 months ago

So there is at least one more issue, shown in the reprex below. When creating the state dict, the parameters of the children are not collected here: https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L543

I think we need to recurse through the children and then things should work.

Setting replace_values = replace_values might also work in some cases, but it runs into problems when different submodules reference the same tensors.

library(torch)
nn_test = nn_module("test",
  initialize = function() {
    self$l = nn_module_list(list(nn_linear(1, 1)))
  },
  forward = function(x) {
    self$l[[1]](x)
  }
)()

nn_test1 = nn_test$clone(deep = TRUE)
nn_test$clone(deep = TRUE)
#> An `nn_module` containing 2 parameters.
#> 
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • l: <nn_module_list> #2 parameters
l1 = nn_test$l$modules[[2]]
l2 = nn_test1$l$modules[[2]]
identical(l1, l2)
#> [1] TRUE

Created on 2024-02-10 with reprex v2.0.2
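
A generic sketch of what "recurse through the children" could look like when building a state dict; this is not torch's actual internals, and `own_parameters` / `children` are hypothetical accessors for a module's direct parameters and direct submodules:

```r
# Hypothetical sketch: collect a module's own parameters, then descend
# into each child, prefixing parameter names with the child's name.
collect_state_dict <- function(module, prefix = "") {
  state <- list()
  for (nm in names(module$own_parameters)) {
    state[[paste0(prefix, nm)]] <- module$own_parameters[[nm]]
  }
  for (nm in names(module$children)) {
    state <- c(
      state,
      collect_state_dict(module$children[[nm]], paste0(prefix, nm, "."))
    )
  }
  state
}
```

Because the recursion visits every child, parameters of nested submodules such as the nn_module_list above would end up in the state dict and get replaced on cloning, instead of being shared with the original.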

sebffischer commented 4 months ago

The current workaround for the clone method caused the cloned object to reference the original object. This was caused by this line: https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L523

As a consequence, the cloned object was larger than the original:

library(torch)

pryr::object_size(nn_relu())
#> 484.47 kB
pryr::object_size(nn_relu()$clone(deep = TRUE))
#> 492.42 kB

Created on 2024-02-12 with reprex v2.0.2

sebffischer commented 4 months ago

@dfalbel I am done here and would love to get your feedback on whether these changes make sense :)

sebffischer commented 4 months ago

Ok, now I think it is actually ready from my side.

sebffischer commented 4 months ago

@dfalbel we could also have a call where I explain some of the changes, if you have the time and think it would be useful or necessary. Otherwise I can give more details here.

dfalbel commented 4 months ago

Edit: please see comment below, this no longer applies.

@sebffischer I wonder if, instead of renaming clone2 to clone, we could find a different name that clearly states what it does. The torch documentation for clone() clearly states:

This function is differentiable, so gradients will flow back from the result of this operation to input. To create a tensor without an autograd relationship to input see detach().

Thus renaming these methods will certainly cause problems in other codebases that rely on that behavior. For instance, it breaks some torch optimizers that use clone() for exactly that.
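
The differentiability the docs describe can be seen in a small example (assumes the torch R package is installed):

```r
library(torch)

x <- torch_tensor(2, requires_grad = TRUE)
y <- torch_clone(x)   # the clone stays connected to the autograd graph
(3 * y)$backward()
x$grad                # the gradient flows back through the clone

z <- x$detach()       # detach() severs the autograd relationship
z$requires_grad       # FALSE
```

So code that relies on clone() producing a graph-connected copy, such as the optimizers mentioned above, observes the change if the method's semantics change.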

Besides that, the PR looks great! Thank you very much for working on this.

dfalbel commented 4 months ago

Ok, so renaming works fine; we just had to support the other arguments to clone() that were not specified. The only thing I'm not convinced about is that we want to set requires_grad. Do you know why exactly we need this? Is it not respected?

sebffischer commented 4 months ago

@dfalbel Thanks for the feedback. Regarding the call to $requires_grad_(): I think I mixed something up. Indeed, the $requires_grad field does not need to be modified manually in the tensor's $clone() method.

sebffischer commented 4 months ago

Please don't merge yet; I want to add one more refactor.

sebffischer commented 4 months ago

Done now

dfalbel commented 4 months ago

Thanks @sebffischer! Looks great! I also ran the luz and minhub test suites and they all passed against this version.