KamitaniLab / bdpy

Python package for brain decoding analysis (BrainDecoderToolbox2 data format, machine learning analysis, functional MRI)
MIT License

Update feature extractor #54

Closed: ganow closed this 1 year ago

ganow commented 1 year ago

We have implemented the following:

Note:

This PR introduces a breaking change to bdpy's API: FeatureExtractor() now only accepts layer names in the following formats:

ganow commented 1 year ago

@ShuntaroAoki Question: Why does FeatureExtractor() accept layer=None?

https://github.com/KamitaniLab/bdpy/blob/bdc866716aab675b311846528a4a275bb075d0fe/bdpy/dl/torch/torch.py#L11-L14

If self.__layer is set to None, the following code will raise an error.

https://github.com/KamitaniLab/bdpy/blob/bdc866716aab675b311846528a4a275bb075d0fe/bdpy/dl/torch/torch.py#L23-L26
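As a hypothetical illustration of one option (not necessarily what bdpy should do), the constructor could reject a missing layer name up front, so the error surfaces at construction time rather than later when the name is first used. `LayerFeatureExtractor` and its attributes are illustrative names, not bdpy's API.

```python
import torch.nn as nn


class LayerFeatureExtractor:
    """Hypothetical extractor that requires a layer name (not bdpy's API)."""

    def __init__(self, encoder: nn.Module, layer: str):
        # Validate eagerly: a None (or non-str) layer would otherwise only
        # fail later, when the name is first used to look up a module.
        if not isinstance(layer, str):
            raise TypeError(f"layer must be a str, got {type(layer).__name__}")
        self._encoder = encoder
        self._layer = layer


try:
    LayerFeatureExtractor(nn.Linear(2, 2), layer=None)
except TypeError as e:
    print(e)  # layer must be a str, got NoneType
```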

ganow commented 1 year ago

I found a useful method, nn.Module.get_submodule(target: str) -> nn.Module. We can use it as follows:

```python
import torch.nn as nn


class MockModule(nn.Module):
    def __init__(self):
        super(MockModule, self).__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layers = nn.Sequential(
            nn.Conv2d(1, 1, 3),
            nn.Conv2d(1, 1, 3),
            nn.Module(),  # empty container; a child is attached to it below
            nn.Sequential(
                nn.Conv2d(1, 1, 4),
                nn.Conv2d(1, 1, 8),
            )
        )
        # self.layers[-2] is the empty nn.Module at index 2. Assigning an
        # nn.Module attribute registers it as a child named "features".
        inner_network = self.layers[-2]
        inner_network.features = nn.Sequential(
            nn.Conv2d(1, 1, 5),
            nn.Conv2d(1, 1, 5)
        )


mock = MockModule()
mock.get_submodule('layer1')  # -> Linear(in_features=10, out_features=10, bias=True)
mock.get_submodule('layers.0')  # -> Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1))
mock.get_submodule('layers.2.features.0')  # -> Conv2d(1, 1, kernel_size=(5, 5), stride=(1, 1))
mock.get_submodule('layers.3.0')  # -> Conv2d(1, 1, kernel_size=(4, 4), stride=(1, 1))
```

Perhaps we should implement _parse_layer_name() using it.

Pros of using get_submodule():

Cons:

Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_submodule
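For context on how a module resolved this way might be used by a feature extractor, the sketch below registers a forward hook on the result of get_submodule() and captures that layer's output during one forward pass. This is an assumption about the approach; `extract_feature` is a hypothetical helper, not bdpy's FeatureExtractor.

```python
import torch
import torch.nn as nn


def extract_feature(model: nn.Module, layer_name: str, x: torch.Tensor) -> torch.Tensor:
    """Run `model` on `x` and return the output of submodule `layer_name`."""
    target = model.get_submodule(layer_name)
    captured = {}

    def hook(module, inputs, output):
        # Called after target's forward(); keep its output.
        captured["out"] = output

    handle = target.register_forward_hook(hook)
    try:
        with torch.no_grad():  # only activations are needed, not gradients
            model(x)
    finally:
        handle.remove()  # unregister so repeated calls do not stack hooks
    return captured["out"]


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
feat = extract_feature(model, "1", torch.randn(3, 4))
print(feat.shape)  # torch.Size([3, 8])
```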