moskomule / anatome

Ἀνατομή is a PyTorch library to analyze representations of neural networks

Do hooks create unexpected side effects in anatome? #21

Closed brando90 closed 2 years ago

brando90 commented 2 years ago

I will try really hard to make this my last question, and I won't bother you again. Do we need to do some sort of clearing after we call hook.distance in anatome (or deep-copy the models) for the code to work properly?

e.g. modification based on your tutorial:

def cxa_dist(mdl1: nn.Module, mdl2: nn.Module, X: Tensor, layer_name: str,
             downsample_size: Optional[str] = None, iters: int = 1, cxa_dist_type: str = 'pwcca') -> float:
    import copy
    mdl1 = copy.deepcopy(mdl1)
    mdl2 = copy.deepcopy(mdl2)
    # get sim/dis functions
    hook1 = SimilarityHook(mdl1, layer_name, cxa_dist_type)
    hook2 = SimilarityHook(mdl2, layer_name, cxa_dist_type)
    mdl1.eval()
    mdl2.eval()
    for _ in range(iters):  # might make sense to run multiple passes if the NN is stochastic, e.g. BN, dropout layers
        mdl1(X)
        mdl2(X)
    # - size: size of the feature map after downsampling
    dist = hook1.distance(hook2, size=downsample_size)
    # - remove hook, to make sure code stops being stateful (I hope)
    remove_hook(mdl1, hook1)
    remove_hook(mdl2, hook2)
    return float(dist)

def remove_hook(mdl: nn.Module, hook):
    """
    ref: https://github.com/pytorch/pytorch/issues/5037
    """
    handle = mdl.register_forward_hook(hook)
    handle.remove()
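
(For reference, the pattern from the linked PyTorch issue is to keep the handle that register_forward_hook returns and call remove() on that same handle; a rough sketch, independent of anatome, and run_with_temporary_hook is just an illustrative name:)

import torch.nn as nn

def run_with_temporary_hook(mdl: nn.Module, fn, x):
    # register_forward_hook returns a handle; calling remove() on that
    # same handle is what detaches the hook that was just registered
    handle = mdl.register_forward_hook(fn)
    try:
        mdl(x)
    finally:
        handle.remove()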

moskomule commented 2 years ago

I don't think there is a side effect (unless it's a PyTorch-side problem), but you can clear it after computing CCA.
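
For example, a rough sketch reusing the names from your snippet above (it assumes the hook's clear() method is what discards the stored activations):

hook1 = SimilarityHook(mdl1, layer_name, cxa_dist_type)
hook2 = SimilarityHook(mdl2, layer_name, cxa_dist_type)
mdl1.eval()
mdl2.eval()
with torch.no_grad():
    mdl1(X)
    mdl2(X)
dist = hook1.distance(hook2, size=downsample_size)
# drop the activations collected during the forward passes,
# so the hooks start from an empty state next time
hook1.clear()
hook2.clear()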

brando90 commented 2 years ago

I don't think there is a side effect (unless it's a PyTorch-side problem), but you can clear it after computing CCA.

So does calling the above cxa_dist twice (i.e. running your Jupyter notebook code twice in a row) create any problems? In my case I have a loop across your basic anatome 101 code snippet (which I wrapped in cxa_dist).

brando90 commented 2 years ago

The snippet I mean is essentially the same as the code above:

import torch
from torchvision.models import resnet18

from anatome import SimilarityHook

model = resnet18()
hook1 = SimilarityHook(model, "layer3.0.conv1")
hook2 = SimilarityHook(model, "layer3.0.conv2")
model.eval()
with torch.no_grad():
    model(data[0])
# downsampling to (size, size) may be helpful
hook1.distance(hook2, size=8)

brando90 commented 2 years ago

I think hook1.clear() is needed, or it keeps collecting tensors. So yes it does have side effects.

brando90 commented 2 years ago

@moskomule can we put hook1.clear() in the README tutorial? Otherwise OOMs happen, since the hooks keep collecting tensors without stopping.
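
Something like this sketch of the README snippet with the clearing step added (same example as above, plus clear()):

from anatome import SimilarityHook

model = resnet18()
hook1 = SimilarityHook(model, "layer3.0.conv1")
hook2 = SimilarityHook(model, "layer3.0.conv2")
model.eval()
with torch.no_grad():
    model(data[0])
# downsampling to (size, size) may be helpful
hook1.distance(hook2, size=8)
# clear the collected tensors so repeated calls don't keep growing memory
hook1.clear()
hook2.clear()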