First, thanks for making this. Lifesaver. Two thoughts (FWIW, the nested functions, higher-order functions, and decorators make things a biiiiit hard to follow when debugging):
1. I initially goofed and didn't put the model in eval mode (even though the very Lucent example notebook I was following does, lol). Maybe `hook_model` could check for `None` feature maps and, if none were saved, tell the user to call `.eval()`?
2. PyTorch module names usually use dot notation (e.g. `layer1.0.conv1`). Maybe use dots instead of underscores? Or just tell the user which feature map names are available, and they'll figure it out quickly enough.
```python
# Nested inside lucent's hook_model(model, image_f); `features` (layer name -> hook
# object) and `image_f` come from the enclosing scope.
def hook(layer):
    if layer == "input":
        out = image_f()
    elif layer == "labels":
        out = list(features.values())[-1].features
    else:
        # suggestion 2 (ish): list the valid layer names
        assert layer in features, f"Invalid layer {layer}. Pick from one of {list(features.keys())}"
        out = features[layer].features
    # suggestion 1: tell the user to eval
    assert out is not None, "There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See the Lucent notebooks for an example."
    return out
```
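On the dot-notation point (suggestion 2), here's a minimal pure-Python sketch of a hypothetical `resolve_layer` helper (not part of Lucent) that `hook` could call before indexing `features`. It assumes the keys are built the way `hook_model` builds them, by joining nested module names with underscores. Since PyTorch submodule names normally can't contain `.` (`add_module` rejects them), swapping dots for underscores should map a dotted `named_modules()` path like `layer1.0.conv1` onto that key. Anything else falls back to listing the available names plus a few close matches.

```python
import difflib
from typing import Iterable


def resolve_layer(layer: str, available: Iterable[str]) -> str:
    """Hypothetical helper: map a user-supplied layer name onto a feature-map key.

    Assumes keys are module names joined with "_" (as in lucent's hook_model),
    so PyTorch's dotted "layer1.0.conv1" corresponds to "layer1_0_conv1".
    """
    keys = list(available)
    if layer in keys:
        return layer
    # Submodule names normally can't contain ".", so replacing dots with
    # underscores turns a dotted named_modules() path into the joined key.
    underscored = layer.replace(".", "_")
    if underscored in keys:
        return underscored
    # Otherwise, tell the user what exists (and what looks similar).
    hints = difflib.get_close_matches(underscored, keys, n=3)
    hint = f" Did you mean one of {hints}?" if hints else ""
    raise ValueError(f"Invalid layer {layer!r}.{hint} Available layers: {keys}")


keys = ["conv1", "bn1", "layer1", "layer1_0", "layer1_0_conv1", "fc"]
print(resolve_layer("layer1.0.conv1", keys))  # prints layer1_0_conv1
```

Inside `hook`, the `else` branch would then become something like `out = features[resolve_layer(layer, features.keys())].features`, so both spellings work and typos get a useful error instead of a bare `KeyError`.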
I ran it on resnet18, btw: gorgeous, and it worked out of the box.
The `hook` snippet above is a suggested replacement for this function: https://github.com/greentfrapp/lucent/blob/a2b015ce95f29460a329f750428077bcde5e4e94/lucent/optvis/render.py#L194