Closed by sebffischer 4 months ago
Also, the C++ code that clones tensors / buffers / parameters behaves differently, right?
I also think it would be nice if `torch_clone(tensor)` and `tensor$clone()` behaved identically.
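A quick way to compare the two code paths side by side (outputs omitted here, since the potential discrepancy is exactly what is in question):

```r
library(torch)

x <- torch_randn(2, requires_grad = TRUE)

# run both code paths on the same tensor and compare the results
torch_clone(x)$requires_grad
x$clone()$requires_grad
```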
There is also another discrepancy between the clonee and the cloned object:
library(torch)
lin = nn_linear(1, 1)
lin$parameters$weight$requires_grad
#> [1] TRUE
lin$clone(deep = TRUE)$parameters$weight$requires_grad
#> [1] FALSE
Created on 2024-02-08 with reprex v2.0.2
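Until this is fixed, a possible workaround (just a sketch, assuming one only needs the flags restored) is to copy the `requires_grad` flags over from the original module after the deep clone:

```r
library(torch)

lin  <- nn_linear(1, 1)
lin2 <- lin$clone(deep = TRUE)

# restore the requires_grad flags from the original module's parameters
for (nm in names(lin$parameters)) {
  lin2$parameters[[nm]]$requires_grad_(lin$parameters[[nm]]$requires_grad)
}
```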
One more:
library(torch)
nn_linear(1, 1)$train()$clone(deep = TRUE)
#> Error in FUN(X[[i]], ...): not an environment
nn_linear(1, 1)$eval()$clone(deep = TRUE)
#> Error in FUN(X[[i]], ...): not an environment
Created on 2024-02-08 with reprex v2.0.2
Also, cloning repeatedly causes issues and builds up a chain of parent environments.
This happens because the clone method retrieved here https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L521 when cloning a second time is already the patched version, not the original R6 clone implementation.
The patched version still has the original version in its enclosing environment, so it still works, but it will repeatedly call the `clone()` method of the enclosing environment (https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L549) until it reaches the top-level clone call, i.e. R6's clone method.
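The chaining effect can be illustrated in plain base R (a sketch of the mechanism only, not torch's actual code): each patch wraps the clone method it retrieved and keeps it in its enclosing environment, so patching an already-patched clone builds a chain.

```r
patch_clone <- function(clone) {
  f <- clone
  function(...) f(...)  # delegates to whatever `f` was at patch time
}

original <- function() "R6 clone"
c1 <- patch_clone(original)
c2 <- patch_clone(c1)  # second patch wraps the already-patched version

# the enclosing environments now form a chain back to the original
identical(environment(c2)$f, c1)                       # TRUE
identical(environment(environment(c2)$f)$f, original)  # TRUE
```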
library(torch)
n = nn_linear(1, 1)
head(attr(n, "module")$clone, n = 1)
#>
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n, "module")$clone |> environment() |> with(clone), n = 1)
#>
#> 1 function (deep = FALSE)
n1 = n$clone(deep = TRUE)
head(attr(n1, "module")$clone, n = 1)
#>
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n1, "module")$clone |> environment() |> with(clone), n = 1)
#>
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n1, "module")$clone |> environment() |> with(clone) |> environment() |> with(clone), n = 1)
#>
#> 1 function (deep = FALSE)
identical(
attr(n1, "module")$clone |> environment(),
attr(n1, "module")$clone |> environment() |> with(clone) |> environment()
)
#> [1] FALSE
attr(n1, "module")$clone |> environment() |> names()
#> [1] "clone" "f" "instance"
attr(n1, "module")$clone |> environment() |> with(clone) |> environment() |> names()
#> [1] "clone" "f" "instance"
Created on 2024-02-10 with reprex v2.0.2
Also, I included support for a private clone finalizer method. In mlr3torch we need something like this, because we have an `nn_module`
that has an R6 class containing modules, but these modules also need to be registered in the `nn_module` class. The reference identity of these objects therefore needs to be preserved when cloning, and the only solution I came up with is to allow a hook that runs after `clone()`. Let me know what you think :)
I also just saw: https://github.com/r-lib/R6/pull/273, which would make the post_clone hook officially supported by R6.
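To illustrate the use case (the hook name `finalize_clone` below is hypothetical and only sketches the idea of a post-clone finalizer): after a deep clone, the module gets a chance to re-establish reference identity between its fields.

```r
library(torch)

# sketch only: `finalize_clone` is a hypothetical hook name
nn_wrapper <- nn_module(
  "nn_wrapper",
  initialize = function() {
    self$net <- nn_linear(1, 1)
    # a second reference to the same module, e.g. inside a registry
    self$registry <- list(net = self$net)
  },
  private = list(
    finalize_clone = function() {
      # after cloning, `self$net` and `self$registry$net` would otherwise
      # point to different objects; re-link them here
      self$registry$net <- self$net
    }
  )
)
```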
So there is at least one more issue, shown in the reprex below. When creating the `state_dict`, the parameters of the children are not collected here: https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L543
I think we need to recurse through the children, and then things should work.
Setting `replace_values = replace_values` might also work in some cases, but that runs into problems when different submodules reference the same tensors.
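The recursion could look roughly like this (a sketch only; `local_parameters` stands in for whatever accessor returns a module's own, non-recursive parameters and is not a real torch field name):

```r
# sketch: recursively gather parameters from a module and all its children
collect_state_dict <- function(module, prefix = "") {
  out <- module$local_parameters  # hypothetical: only this module's own parameters
  if (length(out)) names(out) <- paste0(prefix, names(out))
  for (nm in names(module$children)) {
    out <- c(out, collect_state_dict(module$children[[nm]], paste0(prefix, nm, ".")))
  }
  out
}
```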
library(torch)
nn_test = nn_module("test",
  initialize = function() {
    self$l = nn_module_list(list(nn_linear(1, 1)))
  },
  forward = function(x) {
    self$l[[1]](x)
  }
)()
nn_test1 = nn_test$clone(deep = TRUE)
nn_test$clone(deep = TRUE)
#> An `nn_module` containing 2 parameters.
#>
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • l: <nn_module_list> #2 parameters
l1 = nn_test$l$modules[[2]]
l2 = nn_test1$l$modules[[2]]
identical(l1, l2)
#> [1] TRUE
Created on 2024-02-10 with reprex v2.0.2
The current workaround for the clone method caused the cloned object to reference the original object, due to this line: https://github.com/mlverse/torch/blob/0e9fdd78852601b655eed48a2eaf5e22033dead0/R/nn.R#L523
As a result, the cloned object was larger than the original object:
library(torch)
pryr::object_size(nn_relu())
#> 484.47 kB
pryr::object_size(nn_relu()$clone(deep = TRUE))
#> 492.42 kB
Created on 2024-02-12 with reprex v2.0.2
@dfalbel I am done here and would love to get your feedback whether you think these changes make sense :)
Ok, now I think it is actually ready from my side
@dfalbel we can also have a call where I can explain some of the changes if you have the time / you think this is useful or necessary. Otherwise I can also give more details here
Edit: please see comment below, this no longer applies.
@sebffischer I wonder if, instead of renaming `clone2` to `clone`, we could find a different name that clearly states what it does. The torch documentation clearly states:
This function is differentiable, so gradients will flow back from the result of this operation to input. To create a tensor without an autograd relationship to input see detach().
Thus renaming these methods will certainly cause problems in other codebases that rely on that behavior. For instance, it breaks some torch optimizers that use `clone()` for exactly that.
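That difference in autograd semantics is easy to demonstrate: `$clone()` stays connected to the graph, while `$detach()` breaks the relationship.

```r
library(torch)

x <- torch_randn(3, requires_grad = TRUE)

y <- x$clone()      # differentiable: gradients flow back to x
y$sum()$backward()
x$grad              # now holds ones

z <- x$detach()     # no autograd relationship to x
z$requires_grad
#> [1] FALSE
```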
Besides that, the PR looks great! Thank you very much for working on this.
Ok, so renaming works fine; we just had to support the other arguments to `clone()` that were not specified.
The only thing I'm not convinced about is that we want to set `requires_grad`. Do you know why exactly we need this? Is it not respected?
@dfalbel Thanks for the feedback. Regarding the call to `$requires_grad_()`: I think I mixed something up. Indeed, the `$requires_grad` field does not need to be modified manually in the tensor's `$clone()` method.
Please don't merge yet; I want to add one more refactor.
Done now
Thanks @sebffischer ! Looks great! I also ran luz and minhub tests and they all passed against this version.
Addresses https://github.com/mlverse/torch/issues/1126, where you offered to take care of the renaming. I.e. the `$clone2()` method should just be renamed to `$clone()`, which currently has no effect, as the `find_method()` function (iirc) still finds the `torch_clone()` function that is auto-generated.