apple / corenet

CoreNet: A library for training deep neural networks

Allow applying on all modules, not just immediate children #10

Closed: hub-bla closed this 1 month ago

hub-bla commented 2 months ago

I've implemented nested module selection modeled on the CSS child combinator. Using '>', we can now select nested modules.

Example:

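import argparse
from collections import OrderedDict
from torch import nn

# freeze_modules_based_on_opts as modified on this PR's branch (adds the '>' selector)
from corenet.modeling.misc.common import freeze_modules_based_on_opts

# Freeze 'ins_model', the child of 'model1'
opts = argparse.Namespace(**{"model.freeze_modules": "model1>ins_model"})

inside_model2 = nn.Sequential(
    OrderedDict([
        ('conv1', nn.Conv2d(1, 20, 5)),
        ('relu1', nn.ReLU()),
        ('conv2', nn.Conv2d(20, 64, 5)),
        ('relu2', nn.ReLU()),
    ])
)

inside_model = nn.Sequential(
    OrderedDict([
        ('ins_model', inside_model2),
        ('conv1', nn.Conv2d(1, 20, 5)),
    ])
)

model = nn.Sequential(
    OrderedDict([
        ('model1', inside_model),
        ('conv1', nn.Conv2d(20, 64, 5)),
        ('conv2', nn.Conv2d(20, 64, 5)),
    ])
)

print(freeze_modules_based_on_opts(opts, model))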

Output: example_result (screenshot)
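A minimal sketch of how a '>' selector like this could be resolved, assuming each chunk between '>' symbols is a regex that must fully match the name of an immediate child at that depth. The Node class and select_by_child_path function are hypothetical and not part of CoreNet or this PR; Node only mimics the named_children() method of torch.nn.Module so the sketch runs without torch.

```python
import re
from typing import Iterator, List, Tuple


class Node:
    """Hypothetical stand-in for torch.nn.Module; only named_children() is used."""

    def __init__(self, **children: "Node") -> None:
        self._children = children

    def named_children(self) -> Iterator[Tuple[str, "Node"]]:
        return iter(self._children.items())


def select_by_child_path(root: Node, selector: str) -> List[Tuple[str, Node]]:
    """Resolve a CSS-child-style selector such as "model1>ins_model".

    The selector is split on ">"; chunk i is compiled as a regex and must
    fully match the name of a child at depth i (assumed semantics).
    Returns (dotted_name, module) pairs for every module reached.
    """
    frontier: List[Tuple[str, Node]] = [("", root)]
    for chunk in selector.split(">"):
        pattern = re.compile(chunk)
        next_frontier: List[Tuple[str, Node]] = []
        for prefix, module in frontier:
            for name, child in module.named_children():
                if pattern.fullmatch(name):
                    dotted = f"{prefix}.{name}" if prefix else name
                    next_frontier.append((dotted, child))
        frontier = next_frontier
    return frontier


# Same tree shape as the example above, with layers replaced by empty leaves.
ins_model = Node(conv1=Node(), relu1=Node(), conv2=Node(), relu2=Node())
model = Node(model1=Node(ins_model=ins_model, conv1=Node()), conv1=Node(), conv2=Node())

print([name for name, _ in select_by_child_path(model, "model1>ins_model")])
# ['model1.ins_model']
print([name for name, _ in select_by_child_path(model, "model1>ins_model>conv.*")])
# ['model1.ins_model.conv1', 'model1.ins_model.conv2']
```

Because each chunk is matched only against immediate children, "model1>conv.*" reaches model1.conv1 but not the conv layers nested one level deeper inside ins_model.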

mohammad7t commented 2 months ago

Hi @hub-bla. Thank you for your contribution. Since freeze_modules accepts regular expressions, I'm wondering whether a more flexible regex could select nested modules with the existing code. For example, the following regex seems to work for model1>conv1:

import argparse
from collections import OrderedDict
from torch import nn

from corenet.modeling.misc.common import freeze_modules_based_on_opts

# Regex intended to select conv1 inside model1, using the existing code
opts = argparse.Namespace(**{"model.freeze_modules": r"model1(.*)\.conv1"})

inside_model2 = nn.Sequential(
    OrderedDict([
        ('conv1', nn.Conv2d(1, 20, 5)),
        ('relu1', nn.ReLU()),
        ('conv2', nn.Conv2d(20, 64, 5)),
        ('relu2', nn.ReLU()),
    ])
)

inside_model = nn.Sequential(
    OrderedDict([
        ('ins_model', inside_model2),
        ('conv1', nn.Conv2d(1, 20, 5)),
    ])
)

model = nn.Sequential(
    OrderedDict([
        ('model1', inside_model),
        ('conv1', nn.Conv2d(20, 64, 5)),
        ('conv2', nn.Conv2d(20, 64, 5)),
    ])
)

print(freeze_modules_based_on_opts(opts, model))

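To see what a pattern like r"model1(.*)\.conv1" would select, here is a small check against the dotted parameter names that named_parameters() yields for the model above. It assumes the patterns are applied with re.match (anchored at the start of each name), which is an assumption about the existing code, and frozen_names is a hypothetical helper. Under that assumption, the pattern also matches model1.ins_model.conv1, so it is broader than a strict child selector; escaping the dots, as in r"model1\.conv1\.", keeps only the direct child.

```python
import re
from typing import List

# Parameter names named_parameters() yields for the example model, written out
# by hand so this runs without torch (ReLU layers have no parameters).
PARAM_NAMES = [
    "model1.ins_model.conv1.weight", "model1.ins_model.conv1.bias",
    "model1.ins_model.conv2.weight", "model1.ins_model.conv2.bias",
    "model1.conv1.weight", "model1.conv1.bias",
    "conv1.weight", "conv1.bias",
    "conv2.weight", "conv2.bias",
]


def frozen_names(pattern: str, names: List[str]) -> List[str]:
    """Names a regex would freeze, assuming re.match semantics (hypothetical helper)."""
    return [name for name in names if re.match(pattern, name)]


print(frozen_names(r"model1(.*)\.conv1", PARAM_NAMES))
# ['model1.ins_model.conv1.weight', 'model1.ins_model.conv1.bias',
#  'model1.conv1.weight', 'model1.conv1.bias']
print(frozen_names(r"model1\.conv1\.", PARAM_NAMES))
# ['model1.conv1.weight', 'model1.conv1.bias']
```
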
hub-bla commented 2 months ago

Hi @mohammad7t, I agree that the existing code can support that operation through the loop over named_parameters. To be honest, I didn't check whether it worked before I started implementing the enhancement; I followed the comment above the named_children loop: # TODO: allow applying on all modules, not just immediate chidren? How?
What my code does additionally is reduce the number of log messages. For example, if you freeze a deeply nested module, it won't log every parameter that gets frozen; instead, it only reports that the whole module is now frozen.
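A rough sketch of the difference in log volume, using made-up log messages (not CoreNet's actual log text) and a hypothetical freeze_log_lines helper: logging per parameter emits one line for every tensor under the selected module, while logging per module emits a single line regardless of how deep the nesting goes.

```python
from typing import List

PARAM_NAMES = [
    "model1.ins_model.conv1.weight", "model1.ins_model.conv1.bias",
    "model1.ins_model.conv2.weight", "model1.ins_model.conv2.bias",
    "model1.conv1.weight", "model1.conv1.bias",
]


def freeze_log_lines(param_names: List[str], module_name: str, per_parameter: bool) -> List[str]:
    """Build hypothetical log lines for freezing everything under module_name."""
    selected = [n for n in param_names if n.startswith(module_name + ".")]
    if per_parameter:
        # One line for every parameter tensor that would get requires_grad=False.
        return [f"Freezing parameter: {n}" for n in selected]
    # One line for the whole module, as long as it owns any parameters.
    return [f"Freezing module: {module_name}"] if selected else []


for line in freeze_log_lines(PARAM_NAMES, "model1.ins_model", per_parameter=True):
    print(line)
print(freeze_log_lines(PARAM_NAMES, "model1.ins_model", per_parameter=False))
```
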

To be honest, I wonder whether that loop with named_modules is even necessary, because everything could be done with the approach you provided. It would also resolve this issue.

mohammad7t commented 2 months ago

I see where you are coming from! I agree that the # TODO: allow applying on all modules, not just immediate chidren? How? is confusing. I think the TODO is related to applying force_eval on nested modules: https://github.com/apple/corenet/blob/aaa14a602d22fe3020eb24096483cf2b8c8af4c0/corenet/modeling/misc/common.py#L200-L208

> I wonder if that loop with named_modules is even necessary because everything could be done using [the approach] you provided.

That's a good question. I think the only reason we need the loop with named_modules is to apply force_eval as mentioned above.
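For context, here is a minimal sketch of one way a force-eval override can work, assuming the goal is to keep a frozen module (for example, its BatchNorm statistics) in eval mode even when the parent model calls train(). It is not CoreNet's code: Module is a stand-in that mirrors how torch.nn.Module.train(mode) sets self.training and then calls train(mode) on each child, and force_eval is a hypothetical helper that replaces the train method on one instance.

```python
from types import MethodType
from typing import Iterator, Tuple


class Module:
    """Stand-in mirroring torch.nn.Module.train(): set self.training, recurse into children."""

    def __init__(self, **children: "Module") -> None:
        self.training = True
        self._children = children

    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
        return iter(self._children.items())

    def train(self, mode: bool = True) -> "Module":
        self.training = mode
        for _, child in self.named_children():
            # Looked up on the instance, so a per-instance override is honored.
            child.train(mode)
        return self

    def eval(self) -> "Module":
        return self.train(False)


def force_eval(module: Module) -> Module:
    """Hypothetical helper: keep `module` (and its subtree) in eval mode permanently."""

    def _train_forced_eval(self: Module, mode: bool = True) -> Module:
        # Ignore the requested mode and always switch to eval.
        return Module.train(self, False)

    module.train = MethodType(_train_forced_eval, module)
    return module.eval()


frozen = Module(conv=Module())
model = Module(backbone=frozen, head=Module())
force_eval(frozen)
model.train()  # e.g. called at the start of every training epoch
print(model.training, frozen.training, model._children["head"].training)
# True False True
```
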

I'm not entirely sure what the best solution is right now. Let me think a bit more and get back to you. Thinking out loud, I guess we don't need to support the ">" CSS operator, but the BFS is probably a good idea for addressing the TODO. What do you think?
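A minimal sketch of the BFS idea, under these assumptions: patterns are full-matched against dotted module names, modules are visited breadth-first through named_children(), and a matched module's subtree is not visited further, so each selected module is reported once as a whole (the point where a force-eval override could be applied). Node and bfs_select are hypothetical names, and Node only mimics torch.nn.Module's named_children().

```python
import re
from collections import deque
from typing import Iterator, List, Tuple


class Node:
    """Hypothetical stand-in for torch.nn.Module; only named_children() is used."""

    def __init__(self, **children: "Node") -> None:
        self._children = children

    def named_children(self) -> Iterator[Tuple[str, "Node"]]:
        return iter(self._children.items())


def bfs_select(root: Node, patterns: List[str]) -> List[str]:
    """Breadth-first search for modules whose dotted name fully matches a pattern.

    A matched module's subtree is not explored, so each selection is a whole module.
    """
    compiled = [re.compile(p) for p in patterns]
    selected: List[str] = []
    queue = deque(root.named_children())
    while queue:
        name, module = queue.popleft()
        if any(p.fullmatch(name) for p in compiled):
            selected.append(name)
            continue  # skip the subtree of an already-selected module
        for child_name, child in module.named_children():
            queue.append((f"{name}.{child_name}", child))
    return selected


ins_model = Node(conv1=Node(), relu1=Node(), conv2=Node(), relu2=Node())
model = Node(model1=Node(ins_model=ins_model, conv1=Node()), conv1=Node(), conv2=Node())

print(bfs_select(model, [r"model1\.ins_model"]))  # ['model1.ins_model']
print(bfs_select(model, [r".*conv1"]))  # ['conv1', 'model1.conv1', 'model1.ins_model.conv1']
print(bfs_select(model, [r"model1", r".*conv2"]))  # ['model1', 'conv2']
```

The last call shows the early stop: once model1 is selected, model1.ins_model.conv2 is never visited, even though it would match ".*conv2".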

Thanks again!

hub-bla commented 2 months ago

Thank you for the clarification! Now I get it, and I agree with everything you said.

Regarding the ">" selector, I'm not sure it becomes unnecessary once BFS is used. Currently, the input string is split on this symbol, and the resulting chunks, each of which may or may not be a regex, are passed to the BFS. Without it, we would have to split the string on the dot, which is a special character in regex, and this could cause problems. I'll try to think about that too.
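A small illustration of the splitting concern, using Python's re module and a hypothetical compile_chunks helper: splitting on '>' leaves a regex chunk such as "layer.*" intact, while splitting the equivalent dotted form on '.' cuts it apart and leaves "*", which is not a valid regex on its own.

```python
import re
from typing import List


def compile_chunks(selector: str, sep: str) -> List["re.Pattern[str]"]:
    """Split a selector on `sep` and compile every chunk as a regex (hypothetical helper)."""
    return [re.compile(chunk) for chunk in selector.split(sep)]


# Splitting on ">" keeps the regex chunk "layer.*" intact.
print([p.pattern for p in compile_chunks("layer.*>conv1", ">")])  # ['layer.*', 'conv1']

# Splitting the dotted form on "." breaks "layer.*" into "layer" and "*".
print("layer.*.conv1".split("."))  # ['layer', '*', 'conv1']
try:
    compile_chunks("layer.*.conv1", ".")
except re.error as exc:
    # "*" has nothing to repeat, so compilation fails.
    print("re.error:", exc)
```
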

mohammad7t commented 1 month ago

Hi again, quick update: I don't have a clean solution for resolving the TODO about skipping nested train/eval calls at the moment. Please feel free to reopen or update this issue in the future if there are updates.