KamitaniLab / bdpy

Python package for brain decoding analysis (BrainDecoderToolbox2 data format, machine learning analysis, functional MRI)
MIT License
33 stars 22 forks source link

[fix] feature extraction on detach #56

Closed micchu closed 1 year ago

micchu commented 1 year ago

FeatureExtractorHandleについて,callの中でdetachする動作をデフォルトにすると,icnnで計算グラフが切れるバグが発生します. 従って,FeatureExtractorのdetach引数の値によって,detachするFeatureExtractorHandleクラスと,detachしないFeatureExtractorHandleクラスのどちらをインスタンスするかを切り替えています. ※ FeatureExtractorHandleにうまく引数を渡せなかったのでこのような形式にしていますが,もっと良い実装はあると思います.

ganow commented 1 year ago

@micchu

※ FeatureExtractorHandleにうまく引数を渡せなかった

https://github.com/KamitaniLab/bdpy/blob/aad35e6c9c8272fc82059d06a64c6f042b2f37ca/bdpy/dl/torch/torch.py#L98-L110

例えば上記のコードについて以下のような書き換えをやってもうまく動作しないといったことでしょうか? forward hookについて、 object.__call__(...) はすでにtorch側で決まっているAPIに従う必要がありますがinitializationはこちらのコードベースにあるのでこれで動きそうな予感がしました。

class FeatureExtractorHandle(object):
    def __init__(self, detach: bool = False):
        self.outputs: List[torch.Tensor] = []
        self.detach = detach

    def __call__(
            self,
            module: nn.Module, module_in: Any,
            module_out: torch.Tensor
    ) -> None:
        self.outputs.append(module_out.detach().clone() if self.detach else module_out)

    def clear(self):
        self.outputs = []
ganow commented 1 year ago

@ShuntaroAoki

ReLU(inplace=False) による書き換えを避けたいから tensor.detach().clone() を実行するというのは、よりよい解決ができそうな気がしました。本来 ReLUinplace=True オプションはGPUリソースの節約のためにいれるフラグだと思いますが、 clone() はもとのテンソルのコピーを発生させるため(参考)、結果的にメモリ使用量が増えてしまうように思います。

本質的には、3rd partyが作成したモデルの ReLU(inplace=True)LeakyReLU(inplace=True) などの部分を自動的に inplace=False に書き換えるコードさえあれば FeatureExtractor 側は全て detach=False が標準の挙動で問題ないように思います。

ということで、上記を実現するためのサンプルコードを作成してみましたがいかがでしょうか?ご意見いただきたいです。

from typing import Sequence, Dict, Any, Optional, Type
import torch.nn as nn

def modify_constants(
        model: nn.Module, config: Dict[str, Any],
        bounds: Optional[Sequence[Type]] = None) -> None:
    '''Modify the constants of the model.

    Args:
    -----
        model: The model to be modified.
        config: The configuration of the constants.
        bounds: The sequence of types of the modules to be modified.

    Example:
    --------
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> modify_constants(model, {'inplace': False}, bounds=(nn.ReLU, nn.LeakyReLU))
    '''

    for module in model.modules():
        if not isinstance(module, tuple(bounds)):
            continue
        if not hasattr(module, '__constants__'):
            continue
        for key, value in config.items():
            if key in getattr(module, '__constants__'):
                setattr(module, key, value)

if __name__ == '__main__':
    from torchvision.models import resnet18, vit_b_16

    print('Load ResNet18')
    model = resnet18()

    print('Print all the modules which has inplace')
    [print(module) for module in model.modules() if hasattr(module, 'inplace')]
    # Output: ReLU(inplace=True) ...

    print('Modify the constants of ReLU()')
    modify_constants(model, {'inplace': False}, bounds=(nn.ReLU,))

    print('Print all the modules which has inplace')
    [print(module) for module in model.modules() if hasattr(module, 'inplace')]
    # Output: ReLU() ...

    print('')

    print('Load ViT-B/16')
    model = vit_b_16()

    print('Print all the modules which has inplace')
    [print(module) for module in model.modules() if hasattr(module, 'inplace')]
    # Output: Dropout(p=0.0, inplace=False) ...

    print('Modify the constants of ReLU()')
    modify_constants(model, {'inplace': False}, bounds=(nn.ReLU,))

    print('Print all the modules which has inplace')
    [print(module) for module in model.modules() if hasattr(module, 'inplace')]
    # Output: Dropout(p=0.0, inplace=False) ...

一点注意は、本来 module.__constants__ に列挙されているプロパティは変更されるようにデザインされていないため、上記のコードを実行するタイミングに注意が必要です。具体的には、 このアノテーションはjitコンパイル時に利用されるよう なので、jitコンパイルをする場合にはそれよりも に上記コードを実行する必要がありそうです。