lmz123321 closed this issue 1 year ago.
cc @XuehaiPan
I found that this is intentional behavior in nn.MetaGradientModule, part of its automatic meta-module detection (line 49).
If you pass your mlp with mlp.training = False, then the following code:
class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp
will register the passed mlp as a regular submodule rather than as a meta-module that holds meta-parameters.
Then you will have:
list(implicit_model.meta_modules()) # -> []
list(implicit_model.meta_parameters()) # -> []
list(implicit_model.modules()) # -> [implicit_model, mlp, *mlp_submodules]
list(implicit_model.parameters()) # -> [*mlp_parameters, x]
Expected:
list(implicit_model.meta_modules()) # -> [mlp, *mlp_submodules]
list(implicit_model.meta_parameters()) # -> [*mlp_parameters]
list(implicit_model.modules()) # -> [implicit_model]
list(implicit_model.parameters()) # -> [x]
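The detection rule described above can be modeled in a few lines of plain Python. This is a simplified, hypothetical sketch, not TorchOpt's actual code: FakeModule stands in for torch.nn.Module, MiniMetaModule stands in for nn.MetaGradientModule, and the sketch assumes the detection is a snapshot taken when the module is constructed, keyed on each module argument's training flag.

```python
class FakeModule:
    """Hypothetical stand-in for torch.nn.Module; only `training` matters here."""

    def __init__(self, name):
        self.name = name
        self.training = True  # nn.Module starts in training mode by default

    def eval(self):
        self.training = False
        return self


class MiniMetaModule:
    """Hypothetical model of construction-time meta-module detection."""

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        # Snapshot at construction: only module arguments that are currently
        # in training mode become meta-module candidates. (The real library
        # may flatten nested arguments; this sketch only looks at top level.)
        candidates = [
            x for x in (*args, *kwargs.values())
            if isinstance(x, FakeModule) and x.training
        ]
        object.__setattr__(instance, "_meta_candidates", candidates)
        object.__setattr__(instance, "meta_modules", {})
        object.__setattr__(instance, "modules", {})
        return instance

    def __setattr__(self, name, value):
        if isinstance(value, FakeModule):
            # Route by the construction-time snapshot, not the current flag.
            if any(value is c for c in self._meta_candidates):
                self.meta_modules[name] = value
            else:
                self.modules[name] = value
        object.__setattr__(self, name, value)


class ImplicitModel(MiniMetaModule):
    def __init__(self, mlp):
        self.mlp = mlp  # implicit registration, as in the snippet above


trained = ImplicitModel(FakeModule("mlp"))
print(list(trained.meta_modules), list(trained.modules))  # ['mlp'] []

frozen = ImplicitModel(FakeModule("mlp").eval())
print(list(frozen.meta_modules), list(frozen.modules))    # [] ['mlp']

trained.mlp.eval()
print(list(trained.meta_modules))                         # ['mlp']
```

In this model the candidates are recorded once, at construction time, so calling eval() afterwards does not demote a module that was already registered as a meta-module.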
The fix is a simple one-line change. One approach is to register the meta-module explicitly instead of relying on the implicit detection in the self.mlp = mlp assignment:
import torch
import torch.nn as nn
import torchopt

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.register_meta_module('mlp', mlp)  # <=== HERE ===
        self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)

    # ...

mlp = nn.Sequential(nn.Linear(5, 5), nn.BatchNorm1d(5), nn.Tanh(), nn.Linear(5, 1))
x0 = torch.rand(10, 5)
mlp.eval()
model = ImplicitModel(mlp, x0)
model.solve()
Alternatively, you can pass your module with mlp.training = True and call eval() after it has become a meta-module:
import torch
import torch.nn as nn
import torchopt

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp
        self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)

    # ...

mlp = nn.Sequential(nn.Linear(5, 5), nn.BatchNorm1d(5), nn.Tanh(), nn.Linear(5, 1))
x0 = torch.rand(10, 5)
model = ImplicitModel(mlp, x0)
model.mlp.eval()  # <=== HERE ===
model.solve()
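For reference, the ValueError in the report below is plain Python tuple unpacking: when named_meta_parameters() yields nothing, zip(*...) produces no tuples, so there is nothing to unpack into the two names. The sketch reproduces this with an empty list standing in for self.named_meta_parameters(). The helper split_named_params and its error message are hypothetical, illustrating the kind of early check that would give a clearer message; they are not part of TorchOpt.

```python
def split_named_params(named_params):
    """Split (name, value) pairs into a tuple of names and a tuple of values.

    Hypothetical helper: mirrors the unpacking at line 140 of
    torchopt/diff/implicit/nn/module.py, plus an explicit emptiness check.
    """
    pairs = list(named_params)
    if not pairs:
        raise RuntimeError(
            "no meta-parameters found; was the meta-module passed in eval mode?"
        )
    names, values = tuple(zip(*pairs))
    return names, values


# Reproduce the original failure with an empty iterable of (name, param) pairs.
try:
    names, values = tuple(zip(*[]))
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 0)

print(split_named_params([("weight", 1.0), ("bias", 0.0)]))
# (('weight', 'bias'), (1.0, 0.0))
```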
Required prerequisites
What version of TorchOpt are you using?
0.7.0
System information
System version: 3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]
System platform: linux
TorchOpt version: 0.7.0 (installed via conda)
Torch version: 1.13.1
Functorch version: 1.13.1
Problem description
Hi, I am using the implicit model from torchopt.nn.ImplicitMetaGradientModule. One of my input neural networks contains batch normalization layers, and I want to freeze them when calling implicit_model.solve(). So I use network.eval(), offered by PyTorch. However, TorchOpt throws the following error:
ValueError: not enough values to unpack (expected 2, got 0)
It is raised at line 140 of /torchopt/diff/implicit/nn/module.py:
meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))
What I found about this bug: self.named_meta_parameters() is empty.
Reproducible example code
The Python snippets:
Traceback
Expected behavior
No response
Additional context
No response