johnmarktaylor91 / torchlens

Package for extracting and mapping the results of every single tensor operation in a PyTorch model in one line of code.

Feature request: callback to control which nodes are expanded/collapsed #24

Open · kalekundert opened this issue 3 weeks ago

kalekundert commented 3 weeks ago

When using torchlens to visualize big models, I often wish there was an easier way to hide all of the elementary operations for certain layers. Later on I'll propose an API that would allow this, but I want to start by giving a motivating example. I was just working with one of the U-Net models from https://github.com/lucidrains/denoising-diffusion-pytorch, and my goal was to see how the size of the latent representation changes throughout the model.

Here's the visualization that torchlens produces for this model. You can see that it doesn't make it easy to track the size of the latent representation. Most of the complexity comes from the fact that the model contains lots of ResnetBlock and LinearAttention blocks, which each contain a lot of internal complexity. However, neither of these blocks changes the size of its input, so for the purpose of tracking sizes, whatever happens within them is unimportant. If these blocks were each represented as a single node, the whole graph would be much easier to understand.

I'm aware of the existing vis_nesting_depth option, but I don't think it satisfactorily addresses this issue. First, not all of the ResnetBlock and LinearAttention blocks are necessarily at the same depth, so this setting isn't always capable of collapsing only the nodes I want. Second, it's not easy to know what the depth of each block is, especially in a big model with lots of residual connections. To find a good visualization, you basically have to guess-and-check different depth cutoffs (and be careful to check that nothing important was collapsed).

I think a better API would be to modify ModelHistory.render_graph() to accept a function that will be called for each node and return a boolean value indicating whether that node should be expanded or collapsed. The signature of this function might look something like this:

show_subnodes(node: TensorLogEntry) -> bool

This would allow the user to determine which nodes to collapse based on any property of those nodes. I'm of course interested in doing this based on the name/class of the corresponding layer. But I could imagine this also being useful for expanding only nodes that use lots of memory, or take a long time to evaluate. This API could also replace the existing API, since it also allows for collapsing based on depth (assuming that the nodes have some sort of depth attribute), but it's probably not worth breaking backwards compatibility over.
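
For concreteness, here's a rough sketch of how I imagine using such a callback, assuming render_graph() grew a keyword argument for it. None of this exists yet: the show_subnodes argument and the module_type attribute are just placeholders for whatever module information ends up being available on each node.

import torchlens as tl

# Blocks whose internals never change the tensor size, so they can stay collapsed.
OPAQUE_BLOCKS = {'ResnetBlock', 'LinearAttention'}

def show_subnodes(node):
    # Expand a node only if it isn't inside one of the "opaque" block types.
    return getattr(node, 'module_type', None) not in OPAQUE_BLOCKS

# model and x are the U-Net and example input from the motivating example above.
model_history = tl.log_forward_pass(model, x)
model_history.render_graph(show_subnodes=show_subnodes)  # hypothetical keyword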


A more aggressive version of this API might be to instead have the callback control all aspects of node formatting, e.g.:

format_node(node: TensorLogEntry) -> dict

The dictionary returned by this function would describe how to format the node. Certain keys like expand would be specially extracted and interpreted by torchlens. Any others would just be passed along to graphviz (e.g. color, style, shape, etc.). This seems very elegant to me, but it's definitely a bigger change, and there could be practicalities that I'm overlooking that would complicate things.
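
A purely illustrative sketch of that shape of API (nothing here exists in torchlens today; the format_node hook and the 'expand' key are made up):

def format_node(node):
    # 'expand' would be interpreted by torchlens itself; every other key would be
    # passed straight through to graphviz as a node attribute.
    return {
        'expand': getattr(node, 'module_type', None) not in ('ResnetBlock', 'LinearAttention'),
        'fillcolor': 'lightblue' if 'conv' in node.layer_type else 'white',
        'style': 'filled',
    }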

johnmarktaylor91 commented 3 weeks ago

Thanks for this great suggestion, and for describing it in such detail! I hope TorchLens has been helpful for you. Responses:

I think a better API would be to modify ModelHistory.render_graph() to accept a function that will be called for each node and return a boolean value indicating whether that node should be expanded or collapsed.

This sounds like just the right interface and I like it a lot. There is just one complication: currently in TorchLens the “first class objects” are tensor operations, not modules, where modules are just the “containers” where the operations happen. So, right now the information about modules is not as easy to fetch as the information about tensor operations. It has been on my to-do list to log the module information more nicely. But, with your comment I’ll bump it to the top of my list.

A more aggressive version of this API might be to instead have the callback control all aspects of node formatting, e.g.:

This already exists, actually; it was just added a few weeks ago. The Colab notebook has some examples, but you can pass in a dictionary of functions to override the visualization options according to different aspects of the model (e.g., colorizing layers based on their runtime or memory use). But this doesn't include specifying how to collapse nodes as you suggest above, which would be a great addition.

Thanks again for the feedback, it helps a lot to know people’s use-cases so I know what to focus on.

johnmarktaylor91 commented 3 weeks ago

Here's a minimal example of the custom visualization functionality, btw:

import torch
import torchvision
import torchlens as tl

model = torchvision.models.googlenet()
x = torch.randn(1, 3, 224, 224)

# Define a function that returns a color based on the type of layer.

def colorize_layertype(model_history, layer):
    if 'conv' in layer.layer_type:
        return 'red'
    elif 'pool' in layer.layer_type:
        return 'blue'
    elif 'relu' in layer.layer_type:
        return 'green'
    elif 'linear' in layer.layer_type:
        return 'purple'
    elif 'dropout' in layer.layer_type:
        return 'orange'
    else:
        return 'black'

# Dictionary overriding how to specify each of the node arguments in graphviz.
# Provide a constant value (e.g., "circle") if you want that value to be fixed.
# If you want the value to be computed based on some aspect of the layer,
# provide a function (as above) that takes model_history and layer as
# arguments, and returns the desired value for that argument.

vis_node_overrides = {'label': '',
                      'shape': 'circle',
                      'fixedsize': 'true',
                      'fillcolor': colorize_layertype}

# Override graph label
vis_graph_overrides = {'label': 'Colored by Layer Type'}

# Hide the nested module boxes by turning them white and removing the text.
vis_module_overrides = {'pencolor': 'white',
                        'label': ''}

tl.show_model_graph(model, x, vis_node_overrides=vis_node_overrides, vis_graph_overrides=vis_graph_overrides,
                    vis_module_overrides=vis_module_overrides)
[output: GoogLeNet graph with nodes colored by layer type]

kalekundert commented 3 weeks ago

Thanks for such a quick reply!

There is just one complication: currently in TorchLens the “first class objects” are tensor operations, not modules, where modules are just the “containers” where the operations happen.

Yeah, I'm not surprised that there are some practical considerations that I didn't appreciate. Let me know if you'd be open to a PR for this, or if you'd rather do it yourself along with some more extensive refactoring. I'm not guaranteeing that I'll be able to get to a PR any time soon, or at all, but this project has been very helpful to me and I like the idea of giving something back.

Here's a minimal example of the custom visualization functionality, btw:

That's definitely good to know. It's interesting to me how the current implementation is similar to, but different from, my "aggressive" proposal. The former is a dictionary with values that may be functions, and the latter is a function that returns a dictionary. Basically the difference is just whether the dictionary or the function is the "outermost" entity.

This ship has probably already sailed, but I have to at least comment that the function-returns-dict API seems slightly better to me. I think it would make it easier to calculate multiple related attributes. For example, you might want to control whether the text color is white or black based on the fill color. (Here's some code I wrote to do exactly this in one of my projects. It's a standard algorithm, but I find it rather interesting.) It'd be easiest to write a single function that calculates both colors. With the current API, you'd basically have to calculate the fill color twice.
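
To make that concrete, here's a rough sketch in the function-returns-dict style (format_node is still hypothetical; the luminance formula is the standard WCAG relative-luminance one):

def pick_text_color(fill_rgb):
    # Relative luminance of an sRGB color with channels in [0, 1], per WCAG.
    def linearize(c):
        return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4
    r, g, b = (linearize(c) for c in fill_rgb)
    luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
    # Black text has better contrast than white above roughly this luminance.
    return 'black' if luminance > 0.179 else 'white'

def format_node(node):
    fill_rgb = (0.8, 0.1, 0.1) if 'conv' in node.layer_type else (0.95, 0.95, 0.95)
    return {
        'fillcolor': '#%02x%02x%02x' % tuple(int(255 * c) for c in fill_rgb),
        'fontcolor': pick_text_color(fill_rgb),  # derived from the same fill color
        'style': 'filled',
    }

With the dict-of-functions API, the fontcolor function would have to recompute the fill color itself.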

That said, probably the best way to integrate the expand/collapse function into the current API would be to add an "expand" key to the vis_module_overrides dictionary. That would be a very small change to the API, and a very logical grouping of features; definitely better than my original idea of adding a whole new argument.
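
Something like this, purely as an illustration (the 'expand' key doesn't exist today, and I'm guessing at what the module object passed to the function would look like):

vis_module_overrides = {
    # Hypothetical: only expand module boxes that aren't one of these block types.
    'expand': lambda model_history, module: module.module_type not in ('ResnetBlock', 'LinearAttention'),
}

tl.show_model_graph(model, x, vis_module_overrides=vis_module_overrides)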

johnmarktaylor91 commented 3 weeks ago

That’s a very good point regarding “dict of functions” vs. “function that returns a dict.” You’re totally right about the potential interdependence of options. I just added this feature recently so I don’t think it’s too entrenched to fix. Lemme think about this.

And thanks so much for the offer to do a pull request; so glad to hear TorchLens has helped you out! I think this refactor might make more sense for me to handle, since it's going to involve restructuring how modules are treated in the code (I'll probably make a ModuleLog class or some such thing with all the module info). But I'll bump it to the top of the to-do list.
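
To sketch the kind of thing I have in mind (very rough, nothing implemented, and the field names aren't final):

from dataclasses import dataclass, field
from typing import List

@dataclass
class ModuleLog:
    # One entry per module pass, holding the info needed for collapse/format decisions.
    module_address: str   # e.g. 'downs.0.block1'
    module_type: str      # e.g. 'ResnetBlock'
    nesting_depth: int    # how deeply the module is nested in the model
    child_layers: List[str] = field(default_factory=list)  # labels of tensor ops logged inside it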