OverLordGoldDragon closed this 2 years ago.
Some filter functionality is already implemented, but it is currently not easily accessible to the user. I tried to come up with a minimally invasive way to make it accessible.
My proposed solution would be an interface similar to this:
def create_graph_from_model(model: Model, filter: Optional[Union[str, Callable]] = None) -> EnrichedNetworkNode:
By default, no filtering is done, which is the technically correct way of computing receptive field sizes.
However, this may be misleading for metrics such as r(max) when attention mechanisms come into play (SE modules are a typical example). These modules have low capacity, so they are unlikely to contain much of the critical information that directly helps solve the problem, yet they have an infinite receptive field size (due to their global pooling), which, as you correctly observed, pollutes the computation.
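To see why a single SE module is enough to distort r(max), here is a minimal sketch that treats the receptive field sizes reaching one layer (one per path through the network) as a plain list of numbers, with the globally pooled SE path represented as math.inf. The helper rf_extremes and all values are hypothetical and only illustrate the effect; rfa_toolbox's own computation is more involved.

```python
import math
from typing import Sequence, Tuple


def rf_extremes(sizes: Sequence[float]) -> Tuple[float, float]:
    """Return (r_min, r_max) over the receptive field sizes of all paths.

    Hypothetical helper for illustration only; not part of rfa_toolbox.
    """
    return min(sizes), max(sizes)


# Made-up receptive field sizes of the convolutional paths into a layer.
conv_paths = [35.0, 67.0, 99.0]
# The SE branch uses global pooling, so its receptive field is unbounded.
se_path = [math.inf]

print(rf_extremes(conv_paths))            # (35.0, 99.0)
print(rf_extremes(conv_paths + se_path))  # (35.0, inf)
```

r(min) is unaffected, but r(max) jumps to infinity as soon as the SE path is included.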
A supported alternative would be to filter out all infinite values whenever non-infinite receptive field sizes are present; this can be selected by passing the string key "inf" to the function for extracting the graph.
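The "inf" rule can be sketched as a small function, assuming each layer's receptive field sizes arrive as a tuple of records with a single size field. ReceptiveFieldInfo below is a hypothetical stand-in; the library's actual record type may carry more information. Infinite entries are dropped only if at least one finite entry remains, so a layer whose inputs are all globally pooled keeps its infinite receptive field.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ReceptiveFieldInfo:
    # Hypothetical stand-in for the library's record type.
    size: float


def filter_infinite(
    infos: Tuple[ReceptiveFieldInfo, ...]
) -> Tuple[ReceptiveFieldInfo, ...]:
    """Drop infinite receptive field sizes if any finite ones exist."""
    finite = tuple(info for info in infos if not math.isinf(info.size))
    # If every path is unbounded (e.g. after global pooling), keep them all.
    return finite if finite else infos


mixed = (
    ReceptiveFieldInfo(35.0),
    ReceptiveFieldInfo(math.inf),
    ReceptiveFieldInfo(99.0),
)
print(max(info.size for info in filter_infinite(mixed)))  # 99.0

only_inf = (ReceptiveFieldInfo(math.inf),)
print(filter_infinite(only_inf) == only_inf)  # True
```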
If, at some point, you decide that another filter is necessary, for example one that combines all receptive field sizes into an average receptive field size, you can instead provide a callable of type Callable[[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]] that performs your custom filtering.
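An averaging filter of that kind might look like the sketch below. ReceptiveFieldInfo is again a hypothetical stand-in, and the function only has to follow the shape "tuple of records in, tuple of records out". As a design choice, this version averages only the finite sizes, because a single infinite value would otherwise make the mean infinite; whether and how the library accepts such a callable is an assumption to verify against its documentation.

```python
import math
from dataclasses import dataclass
from statistics import mean
from typing import Callable, Tuple


@dataclass(frozen=True)
class ReceptiveFieldInfo:
    # Hypothetical stand-in for the library's record type.
    size: float


# Shape of a custom filter: all records of a layer in, the kept records out.
RFFilter = Callable[[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]]


def average_filter(
    infos: Tuple[ReceptiveFieldInfo, ...]
) -> Tuple[ReceptiveFieldInfo, ...]:
    """Collapse all finite receptive field sizes into their mean."""
    finite = [info.size for info in infos if not math.isinf(info.size)]
    if not finite:
        # Nothing finite to average (or no input at all): leave unchanged.
        return infos
    return (ReceptiveFieldInfo(mean(finite)),)


my_filter: RFFilter = average_filter
print(my_filter((
    ReceptiveFieldInfo(35.0),
    ReceptiveFieldInfo(67.0),
    ReceptiveFieldInfo(math.inf),
)))
# (ReceptiveFieldInfo(size=51.0),)
```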
This change permeates multiple parts of the library, since it affects all encoders of the framework, so it will take a bit of time to implement and test properly. I hope to get it done sometime next week. Also, let me know if you have any other suggestions regarding the proposed handling.
Yeah, a mix of simple str and extended Callable sounds perfect.
The new release 1.4.2 contains the discussed changes. Here is some example code:
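if __name__ == '__main__':
    from rfa_toolbox import create_graph_from_pytorch_model, visualize_architecture
    from torchvision.models.efficientnet import efficientnet_b7

    # filter_rf="inf" drops infinite receptive field sizes (e.g. from the
    # SE modules) whenever finite ones are available
    graph = create_graph_from_pytorch_model(efficientnet_b7(), filter_rf="inf")
    # Render the annotated architecture graph and open it
    visualize_architecture(graph, "DUMMY", input_res=32).view()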
This results in the squeeze-and-excitation modules being correctly ignored by the rest of the network. Here is one example from the model above:
Pollutes r(max) in SEResNets, and r(min) isn't sufficiently informative
Meant to open this as a feature request.