metaopt / torchopt

TorchOpt is an efficient library for differentiable optimization built upon PyTorch.
https://torchopt.readthedocs.io
Apache License 2.0

[BUG] Got empty meta-parameters using `ImplicitMetaGradientModule` #144

Closed lmz123321 closed 1 year ago

lmz123321 commented 1 year ago

Required prerequisites

What version of TorchOpt are you using?

0.7.0

System information

System version: 3.9.16 (main, Jan 11 2023, 16:05:54)
System platform: [GCC 11.2.0] linux
Torchopt version: 0.7.0 (installed via conda)
Torch version: 1.13.1
Functorch version: 1.13.1

Problem description

Hi, I am using the implicit model from torchopt.nn.ImplicitMetaGradientModule.

One of my input neural networks contains batch normalization layers, and I want to freeze them when calling implicit_model.solve(), so I call network.eval() as provided by PyTorch.

However, TorchOpt raises ValueError: not enough values to unpack (expected 2, got 0). The error comes from line 140 of /torchopt/diff/implicit/nn/module.py:

meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))
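
For context, the unpacking on that line fails whenever the iterator it transposes is empty. The following is a minimal sketch in plain Python (no TorchOpt involved; the parameter names are made up for illustration) of how the zip(*...) idiom behaves with and without entries:

```python
# Non-empty case: zip(*pairs) transposes (name, value) pairs into two tuples.
pairs = [("mlp.0.weight", 1.0), ("mlp.0.bias", 2.0)]  # hypothetical names
names, values = tuple(zip(*pairs))
print(names)   # ('mlp.0.weight', 'mlp.0.bias')
print(values)  # (1.0, 2.0)

# Empty case: zip(*[]) is zip() with no arguments and yields nothing.
unzipped = tuple(zip(*[]))
print(unzipped)  # ()

try:
    empty_names, empty_values = unzipped
except ValueError as exc:
    error_message = str(exc)
    print(error_message)  # not enough values to unpack (expected 2, got 0)
```

So the ValueError is a symptom: it indicates that named_meta_parameters() yielded no entries at all.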


What I find about this bug:

Reproducible example code

The Python snippets:

import torch
import torch.nn as nn
import torchopt

_ = torch.manual_seed(123)
torch.set_default_dtype(torch.float64)

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp
        self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)

    def objective(self):
        return self.mlp(self.x).mean()

    @torch.enable_grad()
    def solve(self):
        optimizer = torch.optim.Adam([self.x], lr=0.01)
        for epoch in range(100):
            optimizer.zero_grad()
            loss = self.objective()
            loss.backward(inputs=[self.x])
            optimizer.step()

mlp = nn.Sequential(nn.Linear(5,5), nn.BatchNorm1d(5), nn.Tanh(), nn.Linear(5,1))
_ = mlp.eval()

# this will work
# for m in mlp.modules():
#   if isinstance(m, nn.BatchNorm1d):
#       m.eval()

x0 = torch.rand(10,5)

model = ImplicitModel(mlp, x0)
model.solve()

Traceback

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/Miniconda3/envs/torch1.13/lib/python3.9/site-packages/torchopt/diff/implicit/n │
│ n/module.py:140 in wrapped                                                                       │
│                                                                                                  │
│   137 │   def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any:                │
│   138 │   │   """Solve the optimization problem."""                                              │
│   139 │   │   params_names, flat_params = tuple(zip(*self.named_parameters()))                   │
│ ❱ 140 │   │   meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))    │
│   141 │   │                                                                                      │
│   142 │   │   flat_optimal_params, output = stateless_solver_fn(                                 │
│   143 │   │   │   flat_params,                                                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: not enough values to unpack (expected 2, got 0)

Benjamin-eecs commented 1 year ago

cc @XuehaiPan

XuehaiPan commented 1 year ago

I found that this is intentional behavior in nn.MetaGradientModule, which detects meta-modules automatically (line 49).

https://github.com/metaopt/torchopt/blob/30db8ecb4b26ed2815edafd49cd84204a288154c/torchopt/nn/module.py#L43-L57

If you are passing your mlp with mlp.training = False, then the following code:

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp

will register the passed mlp as a submodule rather than a meta-module that holds meta-parameters.

Then you will have:

list(implicit_model.meta_modules())    # -> []
list(implicit_model.meta_parameters()) # -> []
list(implicit_model.modules())         # -> [implicit_model, mlp, *mlp_submodules]
list(implicit_model.parameters())      # -> [*mlp_parameters, x]

whereas the expected result is:

list(implicit_model.meta_modules())    # -> [mlp, *mlp_submodules]
list(implicit_model.meta_parameters()) # -> [*mlp_parameters]
list(implicit_model.modules())         # -> [implicit_model]
list(implicit_model.parameters())      # -> [x]
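
To make the detection rule described above concrete, here is a toy model in plain Python. It is only a sketch of the behavior as described in this thread, not TorchOpt's actual implementation: ToyModule and ToyMetaGradientModule are hypothetical stand-ins, and the assumption is simply that a module assigned while training is False is treated as an ordinary submodule, while one assigned while training is True becomes a meta-module.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module that only tracks the training flag."""

    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False
        return self


class ToyMetaGradientModule:
    """Toy container that sorts assigned modules by their training flag."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the two registries.
        object.__setattr__(self, "modules", {})
        object.__setattr__(self, "meta_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, ToyModule):
            # Assumed rule: the flag is inspected once, at assignment time.
            registry = self.meta_modules if value.training else self.modules
            registry[name] = value
        else:
            object.__setattr__(self, name, value)

    def register_meta_module(self, name, module):
        # Explicit registration skips the training-flag check entirely.
        self.meta_modules[name] = module


# Case 1: eval() before assignment -> not detected as a meta-module.
frozen = ToyModule().eval()
a = ToyMetaGradientModule()
a.mlp = frozen
print(sorted(a.meta_modules), sorted(a.modules))  # [] ['mlp']

# Case 2: explicit registration works regardless of the flag.
b = ToyMetaGradientModule()
b.register_meta_module("mlp", frozen)
print(sorted(b.meta_modules))  # ['mlp']

# Case 3: assign in training mode, then switch to eval afterwards.
live = ToyModule()
c = ToyMetaGradientModule()
c.mlp = live
live.eval()
print(sorted(c.meta_modules), live.training)  # ['mlp'] False
```

In this toy, the classification happens once at assignment time, which is why changing the flag afterwards does not move the module between registries.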

The fix is a one-line change. One approach is to register the meta-module explicitly with register_meta_module instead of relying on the implicit self.mlp = mlp assignment:

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.register_meta_module('mlp', mlp)  # <=== HERE ===
        self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)

# ...

mlp = nn.Sequential(nn.Linear(5,5), nn.BatchNorm1d(5), nn.Tanh(), nn.Linear(5,1))
x0 = torch.rand(10,5)

mlp.eval()
model = ImplicitModel(mlp, x0)
model.solve()

Alternatively, pass the module in training mode (mlp.training = True) and call eval() only after it has become a meta-module:

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp
        self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)

# ...

mlp = nn.Sequential(nn.Linear(5,5), nn.BatchNorm1d(5), nn.Tanh(), nn.Linear(5,1))
x0 = torch.rand(10,5)

model = ImplicitModel(mlp, x0)
model.mlp.eval()  # <=== HERE ===
model.solve()
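
With either workaround, it can help to check that meta-parameters were actually detected before calling solve(), so that a misconfiguration surfaces as a readable message rather than an unpacking error. Below is a small sketch of such a guard. require_meta_parameters is a hypothetical helper, not part of TorchOpt; it only assumes the model exposes named_meta_parameters(), the method that appears in the traceback. The stub class exists only so the snippet runs on its own.

```python
def require_meta_parameters(model):
    """Return meta-parameter names, or raise a descriptive error if none exist.

    Hypothetical helper: relies only on model.named_meta_parameters().
    """
    names = [name for name, _ in model.named_meta_parameters()]
    if not names:
        raise RuntimeError(
            "No meta-parameters found. Was the meta-module in eval() mode "
            "when it was assigned? Consider registering it explicitly."
        )
    return names


class StubModel:
    """Stand-in for demonstration; yields fixed (name, value) pairs."""

    def __init__(self, pairs):
        self._pairs = pairs

    def named_meta_parameters(self):
        return iter(self._pairs)


detected = require_meta_parameters(StubModel([("mlp.0.weight", 0.0)]))
print(detected)  # ['mlp.0.weight']

try:
    require_meta_parameters(StubModel([]))
except RuntimeError as exc:
    guard_message = str(exc)
    print(guard_message)
```

With a real model, the call would be require_meta_parameters(model) placed just before model.solve().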