ViCCo-Group / thingsvision

Python package for extracting representations from state-of-the-art computer vision models
https://vicco-group.github.io/thingsvision/
MIT License

Check validity of module name #93

Closed: LukasMut closed this 2 years ago

LukasMut commented 2 years ago

We should add a function that checks whether module_name is valid (i.e., whether the user-specified module name refers to a module that actually exists in the model) before feature extraction starts. Otherwise, the feature extraction method only raises a KeyError after the first iteration, which works but is not a great user experience. I am thinking of something along the lines of

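from typing import List

def get_module_names(self) -> List[str]:
    """Return the names of all submodules of the wrapped model."""
    if self.backend == 'pt':
        # named_modules() yields (name, module) pairs; the root module has
        # the empty name '', which is filtered out.
        module_names, _ = zip(*self.model.named_modules())
        module_names = [name for name in module_names if name]
    else:
        # TensorFlow/Keras: collect the names of all nested submodules.
        module_names = [layer.name for layer in self.model.submodules]
    return module_names

def _is_valid_module(self, module_name: str) -> bool:
    return module_name in self.get_module_names()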

The check should run before the main feature extraction loop starts, so it probably belongs at the top of extractor.extract_features(...).
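A minimal, framework-agnostic sketch of where such a check could sit, assuming module names are already available as a list of strings (as get_module_names above returns). `check_module_name`, `ModuleNameError` and `ToyExtractor` are hypothetical names used only for illustration, not part of the thingsvision API. The near-miss suggestions come from Python's standard `difflib.get_close_matches`.

```python
import difflib
from typing import Iterable, List, Sequence


class ModuleNameError(ValueError):
    """Hypothetical error type raised for an unknown module name."""


def check_module_name(module_name: str, valid_names: Sequence[str]) -> None:
    """Raise before extraction starts if module_name is not a known module."""
    if module_name in valid_names:
        return
    # Suggest near-miss names to help with typos.
    suggestions = difflib.get_close_matches(module_name, valid_names, n=3)
    hint = f" Did you mean one of {suggestions}?" if suggestions else ""
    raise ModuleNameError(f"'{module_name}' is not a valid module name.{hint}")


class ToyExtractor:
    """Hypothetical stand-in for the real extractor; names are hard-coded."""

    def __init__(self, module_names: List[str]) -> None:
        self._module_names = module_names

    def get_module_names(self) -> List[str]:
        return list(self._module_names)

    def extract_features(self, batches: Iterable[list], module_name: str) -> List[str]:
        # Validate once, up front, instead of failing with a KeyError
        # after the first batch has already been processed.
        check_module_name(module_name, self.get_module_names())
        features = []
        for batch in batches:
            # Placeholder "features": module name plus batch size.
            features.append(f"{module_name}:{len(batch)}")
        return features


extractor = ToyExtractor(["features.0", "features.3", "classifier.6"])
print(extractor.extract_features([[1, 2], [3]], "classifier.6"))
# prints ['classifier.6:2', 'classifier.6:1']
```

Calling `extractor.extract_features(..., "classifier.5")` would raise `ModuleNameError` before any batch is processed. Subclassing `ValueError` keeps the error catchable by generic handlers.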

LukasMut commented 2 years ago

This does not seem to be necessary if we resolve issue #89.

andropar commented 2 years ago

I think it's still necessary if we continue to allow users to extract features for individual network modules (which I also think we should do).

LukasMut commented 2 years ago

@andropar I agree.