mert-kurttutan / torchview

torchview: visualize pytorch models
https://torchview.dev
MIT License

Names of Modules? #103

[Open] Erotemic opened this issue 1 year ago

Erotemic commented 1 year ago

Is your feature request related to a problem? Please describe.

I'm looking for a way to visualize the information flow of a network in terms of nested module names. For example, suppose we extract torch modules of specific types like this:

        import torch

        def model_layers(model):
            """ Extract named "leaf" layers from a module """
            stack = [('', '', model)]
            while stack:
                prefix, basename, item = stack.pop()
                name = '.'.join([p for p in [prefix, basename] if p])
                if isinstance(item, torch.nn.modules.conv._ConvNd):
                    yield name, item
                elif isinstance(item, torch.nn.modules.batchnorm._BatchNorm):
                    yield name, item
                elif hasattr(item, 'reset_parameters'):
                    yield name, item

                child_prefix = name
                for child_basename, child_item in list(item.named_children())[::-1]:
                    stack.append((child_prefix, child_basename, child_item))

For torchvision.models.resnet18, this gives:

['conv1', 'bn1', 'layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.conv2', 'layer1.0.bn2', 'layer1.1.conv1', 'layer1.1.bn1', 'layer1.1.conv2', 'layer1.1.bn2', 'layer2.0.conv1', 'layer2.0.bn1', 'layer2.0.conv2', 'layer2.0.bn2', 'layer2.0.downsample.0', 'layer2.0.downsample.1', 'layer2.1.conv1', 'layer2.1.bn1', 'layer2.1.conv2', 'layer2.1.bn2', 'layer3.0.conv1', 'layer3.0.bn1', 'layer3.0.conv2', 'layer3.0.bn2', 'layer3.0.downsample.0', 'layer3.0.downsample.1', 'layer3.1.conv1', 'layer3.1.bn1', 'layer3.1.conv2', 'layer3.1.bn2', 'layer4.0.conv1', 'layer4.0.bn1', 'layer4.0.conv2', 'layer4.0.bn2', 'layer4.0.downsample.0', 'layer4.0.downsample.1', 'layer4.1.conv1', 'layer4.1.bn1', 'layer4.1.conv2', 'layer4.1.bn2', 'fc']
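The model_layers traversal is an iterative pre-order depth-first search. Each stack entry carries its dotted prefix, and children are pushed in reverse so they are popped in definition order, which is why the list above follows the order in which the network declares its submodules. Here is a torch-free sketch of the same idea. It assumes only a `named_children()` method like the one torch.nn.Module provides; the `Node` class and the `dotted_leaf_names` helper are hypothetical:

```python
class Node:
    """Hypothetical stand-in for torch.nn.Module; only named_children() is used."""

    def __init__(self, kind, children=None):
        self.kind = kind
        self._children = dict(children or {})  # dicts preserve insertion order

    def named_children(self):
        return iter(self._children.items())


def dotted_leaf_names(root, is_leaf):
    """Yield (dotted_name, obj) in pre-order for every obj where is_leaf(obj)."""
    stack = [("", root)]
    while stack:
        name, item = stack.pop()
        if is_leaf(item):
            yield name, item
        # Push children in reverse so the first-defined child is popped first.
        for child_name, child in reversed(list(item.named_children())):
            stack.append((f"{name}.{child_name}" if name else child_name, child))


model = Node("ResNetLike", {
    "conv1": Node("Conv2d"),
    "layer1": Node("Sequential", {
        "0": Node("BasicBlock", {"conv1": Node("Conv2d"), "bn1": Node("BatchNorm2d")}),
    }),
    "fc": Node("Linear"),
})
leaf_kinds = {"Conv2d", "BatchNorm2d", "Linear"}
names = [n for n, _ in dotted_leaf_names(model, lambda m: m.kind in leaf_kinds)]
print(names)  # ['conv1', 'layer1.0.conv1', 'layer1.0.bn1', 'fc']
```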

I would like to be able to determine which layers are conceptually connected (i.e. the outputs of one layer eventually make it to the inputs of another layer).
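One way to make "conceptually connected" concrete is to say that layer A feeds layer B when some path from A to B in the computation graph passes only through unnamed nodes, such as functional ops like ReLU or add_. The following is a minimal sketch on a plain adjacency dict. It assumes the compute graph and a node-to-layer-name mapping are already available. The toy graph and the name "stem" are made up and only loosely mimic a residual block:

```python
from collections import deque


def layer_connections(succ, names):
    """Return the set of (src_name, dst_name) layer-to-layer edges.

    succ  -- dict mapping each compute node to a list of its successors (a DAG)
    names -- dict mapping the subset of nodes that are named layers to their name
    """
    edges = set()
    for start, start_name in names.items():
        # Breadth-first search that stops expanding at the first named node on each path.
        queue = deque(succ.get(start, []))
        seen = set()
        while queue:
            node = queue.popleft()
            if node in seen:
                continue
            seen.add(node)
            if node in names:
                edges.add((start_name, names[node]))
            else:
                queue.extend(succ.get(node, []))
    return edges


# Toy residual block: x -> conv -> bn -> relu -> add -> fc, plus the skip x -> add.
succ = {
    "x": ["conv", "add"],
    "conv": ["bn"],
    "bn": ["relu"],
    "relu": ["add"],
    "add": ["fc"],
}
names = {"x": "stem", "conv": "block.conv", "bn": "block.bn", "fc": "fc"}
print(sorted(layer_connections(succ, names)))
# [('block.bn', 'fc'), ('block.conv', 'block.bn'), ('stem', 'block.conv'), ('stem', 'fc')]
```

The skip path appears as a direct stem to fc edge alongside the chain. Removing such shortcuts is therefore a separate step.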

Describe the solution you'd like

Currently when I do something like:

        import torchvision
        from torchview import draw_graph
        net = torchvision.models.resnet18()
        model_graph = draw_graph(net, input_size=(2, 3, 224, 224), device='meta')

the resulting graph does not record any layer names. Would it be possible for the graph-collection code (https://github.com/mert-kurttutan/torchview/blob/main/torchview/computation_graph.py#L188) to also associate layer names with nodes as they are extracted?
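Associating names with nodes mostly comes down to keying names by module instance identity and then looking up each node's module. The following is a minimal sketch. It assumes you already have (dotted_name, module) pairs; with torch, model.named_modules(remove_duplicate=False) should yield these. It also assumes a mapping from each graph node to the id of the module that produced it, similar to the compute_unit_id attribute that the MWE in the next comment reads. The helper names and toy objects are hypothetical:

```python
from collections import defaultdict


def names_by_instance(named_items):
    """Group dotted names by object identity.

    named_items -- iterable of (dotted_name, obj) pairs, e.g. from a module walk
    Returns a dict mapping id(obj) to the list of names it is registered under.
    """
    id_to_names = defaultdict(list)
    for name, obj in named_items:
        id_to_names[id(obj)].append(name)
    return dict(id_to_names)


def label_nodes(node_to_unit_id, id_to_names):
    """Attach a layer name to each graph node whose module instance is known.

    Nodes whose instance has several names get all of them joined by '|',
    so the ambiguity stays visible instead of silently picking one.
    """
    labels = {}
    for node, unit_id in node_to_unit_id.items():
        found = id_to_names.get(unit_id)
        if found:
            labels[node] = "|".join(found)
    return labels


shared = object()  # one instance registered under two names
conv = object()
pairs = [("conv1", conv), ("act_a", shared), ("act_b", shared)]
id_to_names = names_by_instance(pairs)
labels = label_nodes({"n1": id(conv), "n2": id(shared), "n3": None}, id_to_names)
print(labels)  # {'n1': 'conv1', 'n2': 'act_a|act_b'}
```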

Describe alternatives you've considered

I tried doing this myself, but there are a lot of corner cases. This library seems to have the best approach to torch graph extraction that I've seen so far.

Additional context

I raised a similar issue on the torch discussion page: https://discuss.pytorch.org/t/tracing-a-graph-of-torch-layers/187615

Erotemic commented 1 year ago

A minimal working example (MWE) of something close to what I want is:

    import torchvision
    from torchview import draw_graph
    import torch
    import networkx as nx

    def model_layers(model):
        """ Extract named "leaf" layers from a module """
        stack = [('', '', model)]
        while stack:
            prefix, basename, item = stack.pop()
            name = '.'.join([p for p in [prefix, basename] if p])
            if isinstance(item, torch.nn.modules.conv._ConvNd):
                yield name, item
            elif isinstance(item, torch.nn.modules.batchnorm._BatchNorm):
                yield name, item
            elif hasattr(item, 'reset_parameters'):
                yield name, item

            child_prefix = name
            for child_basename, child_item in list(item.named_children())[::-1]:
                stack.append((child_prefix, child_basename, child_item))

    # Create example network
    net = torchvision.models.resnet18()
    model_graph = draw_graph(net, input_size=(2, 3, 224, 224), device='meta')

    # Remember the dotted layer name associated with each torch.Module
    # instance.  Usually a module will just have one name associated to an
    # instance, but it could have more than one.
    from collections import defaultdict
    named_layers = list(model_layers(net))
    id_to_names = defaultdict(list)
    for name, layer in named_layers:
        layer_id = id(layer)
        id_to_names[layer_id].append(name)

    def make_label(n, data):
        """ Create a nice printable label """
        n_id = id(n)
        n_id_str = str(n_id)
        parts = []
        if 'layer_name' in data:
            parts.append(data['layer_name'] + ':')
        parts.append(n.name)
        if n_id_str in model_graph.id_dict:
            idx = model_graph.id_dict[n_id_str]
            parts.append(f':{idx}')

        if n_id in id_to_names:
            # id_to_names values are lists, so join them before concatenating
            parts.append(' ' + ', '.join(id_to_names[n_id]))

        label = ''.join(parts)
        return label

    # Build a networkx version of the torchview model graph
    graph = nx.DiGraph()
    for node in model_graph.node_set:
        graph.add_node(node)

    for u, v in model_graph.edge_list:
        u_id = id(u)
        v_id = id(v)
        graph.add_edge(u_id, v_id)
        graph.nodes[u_id]['compute_node'] = u
        graph.nodes[v_id]['compute_node'] = v

    # Enrich each node with more info
    for n_id, data in graph.nodes(data=True):
        if 'compute_node' in data:
            n = data['compute_node']
            if hasattr(n, 'compute_unit_id'):
                if n.compute_unit_id in id_to_names:
                    layer_names = id_to_names[n.compute_unit_id]
                    if len(layer_names) == 1:
                        data['layer_name'] = layer_names[0]
                    else:
                        # Shared instance: keep every name, not just the first
                        data['layer_names'] = layer_names
            data['label'] = make_label(n, data)

    nx.write_network_text(graph, vertical_chains=1)
    # model_graph.visual_graph.view()

Produces:

╟── 139679377001936
╙── auxiliary-tensor
    ╽
    conv1:Conv2d:1
    ╽
    bn1:BatchNorm2d:2
    ╽
    ReLU:3
    ╽
    MaxPool2d:4
    ├─╼ layer1.0.conv1:Conv2d:5
    │   ╽
    │   layer1.0.bn1:BatchNorm2d:6
    │   ╽
    │   ReLU:7
    │   ╽
    │   layer1.0.conv2:Conv2d:8
    │   ╽
    │   layer1.0.bn2:BatchNorm2d:9
    │   ╽
    │   add_:10 ╾ MaxPool2d:4
    │   ╽
    │   ReLU:11
    │   ├─╼ layer1.1.conv1:Conv2d:12
    │   │   ╽
    │   │   layer1.1.bn1:BatchNorm2d:13
    │   │   ╽
    │   │   ReLU:14
    │   │   ╽
    │   │   layer1.1.conv2:Conv2d:15
    │   │   ╽
    │   │   layer1.1.bn2:BatchNorm2d:16
    │   │   ╽
    │   │   add_:17 ╾ ReLU:11
    │   │   ╽
    │   │   ReLU:18
    │   │   ├─╼ layer2.0.conv1:Conv2d:19
    │   │   │   ╽
    │   │   │   layer2.0.bn1:BatchNorm2d:20
    │   │   │   ╽
    │   │   │   ReLU:21
    │   │   │   ╽
    │   │   │   layer2.0.conv2:Conv2d:22
    │   │   │   ╽
    │   │   │   layer2.0.bn2:BatchNorm2d:23
    │   │   │   ╽
    │   │   │   add_:25 ╾ Sequential:24
    │   │   │   ╽
    │   │   │   ReLU:26
    │   │   │   ├─╼ layer2.1.conv1:Conv2d:27
    │   │   │   │   ╽
    │   │   │   │   layer2.1.bn1:BatchNorm2d:28
    │   │   │   │   ╽
    │   │   │   │   ReLU:29
    │   │   │   │   ╽
    │   │   │   │   layer2.1.conv2:Conv2d:30
    │   │   │   │   ╽
    │   │   │   │   layer2.1.bn2:BatchNorm2d:31
    │   │   │   │   ╽
    │   │   │   │   add_:32 ╾ ReLU:26
    │   │   │   │   ╽
    │   │   │   │   ReLU:33
    │   │   │   │   ├─╼ layer3.0.conv1:Conv2d:34
    │   │   │   │   │   ╽
    │   │   │   │   │   layer3.0.bn1:BatchNorm2d:35
    │   │   │   │   │   ╽
    │   │   │   │   │   ReLU:36
    │   │   │   │   │   ╽
    │   │   │   │   │   layer3.0.conv2:Conv2d:37
    │   │   │   │   │   ╽
    │   │   │   │   │   layer3.0.bn2:BatchNorm2d:38
    │   │   │   │   │   ╽
    │   │   │   │   │   add_:40 ╾ Sequential:39
    │   │   │   │   │   ╽
    │   │   │   │   │   ReLU:41
    │   │   │   │   │   ├─╼ layer3.1.conv1:Conv2d:42
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   layer3.1.bn1:BatchNorm2d:43
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   ReLU:44
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   layer3.1.conv2:Conv2d:45
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   layer3.1.bn2:BatchNorm2d:46
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   add_:47 ╾ ReLU:41
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   ReLU:48
    │   │   │   │   │   │   ├─╼ layer4.0.conv1:Conv2d:49
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   layer4.0.bn1:BatchNorm2d:50
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   ReLU:51
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   layer4.0.conv2:Conv2d:52
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   layer4.0.bn2:BatchNorm2d:53
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   add_:55 ╾ Sequential:54
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   ReLU:56
    │   │   │   │   │   │   │   ├─╼ layer4.1.conv1:Conv2d:57
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   layer4.1.bn1:BatchNorm2d:58
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   ReLU:59
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   layer4.1.conv2:Conv2d:60
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   layer4.1.bn2:BatchNorm2d:61
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   add_:62 ╾ ReLU:56
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   ReLU:63
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   AdaptiveAvgPool2d:64
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   flatten:65
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   fc:Linear:66
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   output-tensor:67
    │   │   │   │   │   │   │   └─╼  ...
    │   │   │   │   │   │   └─╼ Sequential:54
    │   │   │   │   │   │       └─╼  ...
    │   │   │   │   │   └─╼  ...
    │   │   │   │   └─╼ Sequential:39
    │   │   │   │       └─╼  ...
    │   │   │   └─╼  ...
    │   │   └─╼ Sequential:24
    │   │       └─╼  ...
    │   └─╼  ...
    └─╼  ...

You can see here that I've been able to associate many of the nodes with their original layer names. However, my solution simply assumes that each module instance is used only once. I think a correct solution would need to know which module attribute was the caller, which gets tricky if a module instance is assigned to multiple variables.
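Besides one instance having several names, the one-instance-one-node assumption also breaks when a single instance is called more than once. For example, torchvision's BasicBlock applies its single self.relu twice. The following is a rough heuristic sketch. It assumes nodes can be listed in execution or topological order and mapped to module instance ids, and it numbers repeated calls so they at least stay distinct. It does not recover which attribute was the caller. The helper and the toy ids are hypothetical:

```python
from collections import Counter, defaultdict


def disambiguate_calls(nodes_in_order, node_unit, unit_name):
    """Give each call of a reused module its own label.

    nodes_in_order -- graph nodes in execution (or topological) order
    node_unit      -- dict: node -> id of the module instance that produced it
    unit_name      -- dict: module instance id -> dotted layer name
    Returns dict: node -> label, appending '#k' for the k-th call when an
    instance is called more than once. Nodes without a known module are skipped.
    """
    call_counts = Counter(node_unit[n] for n in nodes_in_order if n in node_unit)
    seen = defaultdict(int)
    labels = {}
    for node in nodes_in_order:
        unit = node_unit.get(node)
        if unit is None or unit not in unit_name:
            continue
        name = unit_name[unit]
        if call_counts[unit] > 1:
            name = f"{name}#{seen[unit]}"
            seen[unit] += 1
        labels[node] = name
    return labels


# Mimics a BasicBlock in which one ReLU instance (id 3) is applied twice.
order = ["conv1", "bn1", "relu_a", "conv2", "bn2", "add", "relu_b"]
node_unit = {"conv1": 1, "bn1": 2, "relu_a": 3, "conv2": 4, "bn2": 5, "relu_b": 3}
unit_name = {1: "conv1", 2: "bn1", 3: "relu", 4: "conv2", 5: "bn2"}
labels = disambiguate_calls(order, node_unit, unit_name)
print(labels)
```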

To further process this into what I'm actually interested in, I do something like this:

    # Now that we have a graph where a subset of nodes correspond to known
    # layers, we can postprocess it to only show effective connections between
    # the layers.

    # Determine which nodes have associated layer names
    remove_ids = []
    keep_ids = []
    for n_id, data in graph.nodes(data=True):
        if 'layer_name' in data:
            keep_ids.append(n_id)
        else:
            remove_ids.append(n_id)

    import ubelt as ub
    topo_order = ub.OrderedSet(nx.topological_sort(graph))
    keep_topo_order = (topo_order & keep_ids)

    # Find the nearest ancestor that we want to view and collapse the nodes we
    # don't care about into it. Do a final relabeling to keep the original node
    # ids where possible.
    collapseables = defaultdict(list)
    for n in remove_ids:
        valid_prev_nodes = keep_topo_order & set(nx.ancestors(graph, n))
        if valid_prev_nodes:
            p = valid_prev_nodes[-1]
            collapseables[p].append(n)
    from networkx.algorithms.connectivity.edge_augmentation import collapse
    grouped_nodes = []
    for p, vs in collapseables.items():
        grouped_nodes.append([p, *vs])
    g2 = collapse(graph, grouped_nodes)
    relabel = {n: n for n in g2.nodes}
    new_to_olds = ub.udict(g2.graph['mapping']).invert(unique_vals=0)
    for new, olds in new_to_olds.items():
        if len(olds) == 1:
            old = ub.peek(olds)
            relabel[new] = old
        else:
            keep_olds = keep_topo_order & olds
            old = ub.peek(keep_olds)
            relabel[new] = old
    g3 = nx.relabel_nodes(g2, relabel)

    def transfer_data(g_dst, g_src):
        for n in set(g_dst.nodes) & set(g_src.nodes):
            g_dst.nodes[n].update(g_src.nodes[n])

    # Show the collapsed graph
    transfer_data(g3, graph)
    nx.write_network_text(g3, vertical_chains=1)

    # Further reduce the graph to remove skip connection information
    g4 = nx.transitive_reduction(g3)
    transfer_data(g4, graph)
    nx.write_network_text(g4, vertical_chains=1)

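    # Alternative (not shown in the output below): transitive reduction of the
    # full, uncollapsed graph. For a DAG, taking the closure first does not
    # change the result, so this matches nx.transitive_reduction(graph).
    g2 = nx.transitive_closure(graph)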
    g2 = nx.transitive_reduction(g2)
    transfer_data(g2, graph)
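The collapse step above assigns each unnamed node to the named ancestor that comes last in topological order, then keeps only the edges between different groups. The following is a standard-library sketch of that rule using graphlib (Python 3.9+) on a toy residual block. The node names are hypothetical. The per-node ancestor search is quadratic, which is fine for illustration but not tuned for large graphs:

```python
from graphlib import TopologicalSorter


def collapse_to_named_ancestor(succ, keep):
    """Merge every non-kept node into its latest kept ancestor (by topological order).

    succ -- dict: node -> list of successors (a DAG)
    keep -- set of nodes to keep (the named layers)
    Returns (edges, owner): the collapsed edge set and the node -> representative
    mapping. Nodes without any kept ancestor represent themselves.
    """
    pred = {n: [] for n in succ}
    for u, vs in succ.items():
        for v in vs:
            pred.setdefault(v, []).append(u)
    order = list(TopologicalSorter(pred).static_order())
    rank = {n: i for i, n in enumerate(order)}

    owner = {}
    for n in order:
        if n in keep:
            owner[n] = n
            continue
        # Collect all ancestors of n with an explicit stack.
        ancestors, stack = set(), list(pred[n])
        while stack:
            a = stack.pop()
            if a not in ancestors:
                ancestors.add(a)
                stack.extend(pred[a])
        kept = [a for a in ancestors if a in keep]
        owner[n] = max(kept, key=rank.__getitem__) if kept else n

    edges = {(owner[u], owner[v]) for u, vs in succ.items() for v in vs
             if owner[u] != owner[v]}
    return edges, owner


succ = {
    "bn1": ["relu"], "relu": ["pool"], "pool": ["conv", "add"],
    "conv": ["bn2"], "bn2": ["add"], "add": ["relu2"], "relu2": ["fc"],
}
edges, owner = collapse_to_named_ancestor(succ, keep={"bn1", "conv", "bn2", "fc"})
print(sorted(edges))
# [('bn1', 'bn2'), ('bn1', 'conv'), ('bn2', 'fc'), ('conv', 'bn2')]
```

The ('bn1', 'bn2') edge is the residual shortcut. It survives collapsing, just like the "layer1.0.bn2 ╾ bn1" edge in the collapsed output.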

This shows the graph with the intermediate functional nodes (ReLU, add_, pooling, flatten) collapsed into their nearest preceding named layer:

╟── auxiliary-tensor
╎   ╽
╎   conv1:Conv2d:1
╎   ╽
╎   bn1:BatchNorm2d:2
╎   ├─╼ layer1.0.conv1:Conv2d:5
╎   │   ╽
╎   │   layer1.0.bn1:BatchNorm2d:6
╎   │   ╽
╎   │   layer1.0.conv2:Conv2d:8
╎   │   ╽
╎   │   layer1.0.bn2:BatchNorm2d:9 ╾ bn1:BatchNorm2d:2
╎   │   ├─╼ layer1.1.conv1:Conv2d:12
╎   │   │   ╽
╎   │   │   layer1.1.bn1:BatchNorm2d:13
╎   │   │   ╽
╎   │   │   layer1.1.conv2:Conv2d:15
╎   │   │   ╽
╎   │   │   layer1.1.bn2:BatchNorm2d:16 ╾ layer1.0.bn2:BatchNorm2d:9
╎   │   │   ├─╼ layer2.0.bn2:BatchNorm2d:23 ╾ layer2.0.conv2:Conv2d:22
╎   │   │   │   ├─╼ layer2.1.conv1:Conv2d:27
╎   │   │   │   │   ╽
╎   │   │   │   │   layer2.1.bn1:BatchNorm2d:28
╎   │   │   │   │   ╽
╎   │   │   │   │   layer2.1.conv2:Conv2d:30
╎   │   │   │   │   ╽
╎   │   │   │   │   layer2.1.bn2:BatchNorm2d:31 ╾ layer2.0.bn2:BatchNorm2d:23
╎   │   │   │   │   ├─╼ layer3.0.bn2:BatchNorm2d:38 ╾ layer3.0.conv2:Conv2d:37
╎   │   │   │   │   │   ├─╼ layer3.1.conv1:Conv2d:42
╎   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   layer3.1.bn1:BatchNorm2d:43
╎   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   layer3.1.conv2:Conv2d:45
╎   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   layer3.1.bn2:BatchNorm2d:46 ╾ layer3.0.bn2:BatchNorm2d:38
╎   │   │   │   │   │   │   ├─╼ layer4.0.bn2:BatchNorm2d:53 ╾ layer4.0.conv2:Conv2d:52
╎   │   │   │   │   │   │   │   ├─╼ layer4.1.conv1:Conv2d:57
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   layer4.1.bn1:BatchNorm2d:58
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   layer4.1.conv2:Conv2d:60
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   layer4.1.bn2:BatchNorm2d:61 ╾ layer4.0.bn2:BatchNorm2d:53
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   fc:Linear:66
╎   │   │   │   │   │   │   │   └─╼  ...
╎   │   │   │   │   │   │   └─╼ layer4.0.conv1:Conv2d:49
╎   │   │   │   │   │   │       ╽
╎   │   │   │   │   │   │       layer4.0.bn1:BatchNorm2d:50
╎   │   │   │   │   │   │       ╽
╎   │   │   │   │   │   │       layer4.0.conv2:Conv2d:52
╎   │   │   │   │   │   │       └─╼  ...
╎   │   │   │   │   │   └─╼  ...
╎   │   │   │   │   └─╼ layer3.0.conv1:Conv2d:34
╎   │   │   │   │       ╽
╎   │   │   │   │       layer3.0.bn1:BatchNorm2d:35
╎   │   │   │   │       ╽
╎   │   │   │   │       layer3.0.conv2:Conv2d:37
╎   │   │   │   │       └─╼  ...
╎   │   │   │   └─╼  ...
╎   │   │   └─╼ layer2.0.conv1:Conv2d:19
╎   │   │       ╽
╎   │   │       layer2.0.bn1:BatchNorm2d:20
╎   │   │       ╽
╎   │   │       layer2.0.conv2:Conv2d:22
╎   │   │       └─╼  ...
╎   │   └─╼  ...
╎   └─╼  ...
╙── 139679377001936

And finally, what I ultimately want to see is the transitive reduction of this graph:

╟── auxiliary-tensor
╎   ╽
╎   conv1:Conv2d:1
╎   ╽
╎   bn1:BatchNorm2d:2
╎   ╽
╎   layer1.0.conv1:Conv2d:5
╎   ╽
╎   layer1.0.bn1:BatchNorm2d:6
╎   ╽
╎   layer1.0.conv2:Conv2d:8
╎   ╽
╎   layer1.0.bn2:BatchNorm2d:9
╎   ╽
╎   layer1.1.conv1:Conv2d:12
╎   ╽
╎   layer1.1.bn1:BatchNorm2d:13
╎   ╽
╎   layer1.1.conv2:Conv2d:15
╎   ╽
╎   layer1.1.bn2:BatchNorm2d:16
╎   ╽
╎   layer2.0.conv1:Conv2d:19
╎   ╽
╎   layer2.0.bn1:BatchNorm2d:20
╎   ╽
╎   layer2.0.conv2:Conv2d:22
╎   ╽
╎   layer2.0.bn2:BatchNorm2d:23
╎   ╽
╎   layer2.1.conv1:Conv2d:27
╎   ╽
╎   layer2.1.bn1:BatchNorm2d:28
╎   ╽
╎   layer2.1.conv2:Conv2d:30
╎   ╽
╎   layer2.1.bn2:BatchNorm2d:31
╎   ╽
╎   layer3.0.conv1:Conv2d:34
╎   ╽
╎   layer3.0.bn1:BatchNorm2d:35
╎   ╽
╎   layer3.0.conv2:Conv2d:37
╎   ╽
╎   layer3.0.bn2:BatchNorm2d:38
╎   ╽
╎   layer3.1.conv1:Conv2d:42
╎   ╽
╎   layer3.1.bn1:BatchNorm2d:43
╎   ╽
╎   layer3.1.conv2:Conv2d:45
╎   ╽
╎   layer3.1.bn2:BatchNorm2d:46
╎   ╽
╎   layer4.0.conv1:Conv2d:49
╎   ╽
╎   layer4.0.bn1:BatchNorm2d:50
╎   ╽
╎   layer4.0.conv2:Conv2d:52
╎   ╽
╎   layer4.0.bn2:BatchNorm2d:53
╎   ╽
╎   layer4.1.conv1:Conv2d:57
╎   ╽
╎   layer4.1.bn1:BatchNorm2d:58
╎   ╽
╎   layer4.1.conv2:Conv2d:60
╎   ╽
╎   layer4.1.bn2:BatchNorm2d:61
╎   ╽
╎   fc:Linear:66
╙── 139679377001936
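Transitive reduction keeps an edge only if no longer path already implies it. This is why the skip connections (e.g. bn1 to layer1.0.bn2) disappear while the main chain remains. The networkx transitive_reduction call used above does this on the real graph. The sketch below only illustrates the rule on a small dict-based DAG. It assumes there are no duplicate edges, and the node names are made up to mirror the residual pattern:

```python
def transitive_reduction(succ):
    """Transitive reduction of a DAG given as node -> list of successors.

    An edge (u, v) is dropped when v is still reachable from u through some
    other successor w of u, i.e. the edge is implied by a longer path.
    """
    def reachable(start):
        seen, stack = set(), [start]
        while stack:
            n = stack.pop()
            for m in succ.get(n, []):
                if m not in seen:
                    seen.add(m)
                    stack.append(m)
        return seen

    reach = {n: reachable(n) for n in succ}
    return {
        u: [v for v in vs if not any(v in reach.get(w, ()) for w in vs if w != v)]
        for u, vs in succ.items()
    }


collapsed = {"bn1": ["conv", "bn2"], "conv": ["bn2"], "bn2": ["fc"]}
reduced = transitive_reduction(collapsed)
print(reduced)  # {'bn1': ['conv'], 'conv': ['bn2'], 'bn2': ['fc']}
```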