waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, TensorFlow, and Keras.
MIT License

Is it possible to control which part of the model graph is printed? #21

Open noklam opened 5 years ago

noklam commented 5 years ago

The graph is useful for visualizing a model, but when the model is big, printing the entire graph is not very helpful. Say I already have the structure of the model (an OrderedDict or groups of layers): is it possible to print only part of the model graph?

I am using PyTorch with a pre-trained ResNet, and I am only interested in the layers after the ResNet encoder. Ideally, I would like to print something like a single ResNet block (maybe just passed in as a manual input argument) plus the details of the layer groups I'm interested in.

waleedka commented 5 years ago

This would be a useful feature. It's not available right now, but you can probably do it with a bit of additional code. I'll tag this issue as an enhancement so hopefully we add it later, or maybe someone adds it and submits a pull request.

HiddenLayer stores the neural network as a graph (a series of nodes stored as a dict, and a series of edges stored as a list, in the Graph class). You can easily manipulate that graph any way you like. For example, you can delete all the nodes except the ones you want to print.
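As a rough illustration of that data layout, here is a minimal sketch built on a simplified stand-in for the Graph class: nodes in a dict keyed by ID, edges as a list of (source_id, target_id) tuples. The `Node` and `Graph` classes and the `remove` method below are hypothetical and are not HiddenLayer's actual code; the real classes carry more fields. The point is only that "delete everything except what you want" amounts to removing the complement of the kept IDs, plus any edge left dangling.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    id: str   # unique node ID
    op: str   # operation name, e.g. "Conv"


@dataclass
class Graph:
    # Hypothetical stand-in for HiddenLayer's Graph, not the real class.
    nodes: dict = field(default_factory=dict)  # node ID -> Node
    edges: list = field(default_factory=list)  # (source_id, target_id) tuples

    def remove(self, ids):
        """Delete the given node IDs and every edge that touches one of them."""
        ids = set(ids)
        for nid in ids:
            self.nodes.pop(nid, None)
        self.edges = [(a, b) for a, b in self.edges if a not in ids and b not in ids]


g = Graph(
    nodes={"n1": Node("n1", "Conv"), "n2": Node("n2", "Relu"), "n3": Node("n3", "Linear")},
    edges=[("n1", "n2"), ("n2", "n3")],
)
# Delete every node except the ones we want to print.
keep = {"n2", "n3"}
g.remove(set(g.nodes) - keep)
print(sorted(g.nodes))  # ['n2', 'n3']
print(g.edges)          # [('n2', 'n3')]
```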

Take a look at the Prune transform in transforms.py. It's a simple transform (~15 lines of code) that removes all nodes that match a pattern. You need to do the opposite: delete all nodes that do not match the pattern. Let's call it the Subgraph transform. If you write that, then you can use it as follows:

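```python
import hiddenlayer as hl
graph = hl.build_graph(pytorch_graph, inputs)
# Subgraph is the new transform you would write (not part of HiddenLayer yet)
sub_graph = Subgraph(pattern_of_nodes_to_keep).apply(graph)
```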
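One possible shape for such a transform is sketched below, again against a simplified stand-in graph rather than HiddenLayer's real classes. As an assumption for illustration, `pattern` is a plain regular expression matched against each node's op name; HiddenLayer's graph-expression patterns are richer (they can match sequences of ops), so treat this as a simplification, not the library's API. The transform deep-copies its input so the original graph can still be rendered in full.

```python
import copy
import re
from dataclasses import dataclass, field


@dataclass
class Node:
    id: str
    op: str


@dataclass
class Graph:
    # Hypothetical stand-in for HiddenLayer's Graph.
    nodes: dict = field(default_factory=dict)  # node ID -> Node
    edges: list = field(default_factory=list)  # (source_id, target_id) tuples


class Subgraph:
    """Hypothetical transform: keep only nodes whose op fully matches `pattern`.

    This is the inverse of pruning. `pattern` is a plain regex on the op name,
    a simplification of HiddenLayer's graph-expression patterns.
    """

    def __init__(self, pattern):
        self.pattern = re.compile(pattern)

    def apply(self, graph):
        graph = copy.deepcopy(graph)  # leave the caller's graph untouched
        graph.nodes = {nid: n for nid, n in graph.nodes.items()
                       if self.pattern.fullmatch(n.op)}
        # Drop edges whose endpoints no longer exist.
        graph.edges = [(a, b) for a, b in graph.edges
                       if a in graph.nodes and b in graph.nodes]
        return graph


g = Graph(
    nodes={"n1": Node("n1", "Conv"), "n2": Node("n2", "Relu"),
           "n3": Node("n3", "Linear"), "n4": Node("n4", "Softmax")},
    edges=[("n1", "n2"), ("n2", "n3"), ("n3", "n4")],
)
sub = Subgraph(r"Linear|Softmax").apply(g)
print(sorted(sub.nodes))  # ['n3', 'n4']
print(sub.edges)          # [('n3', 'n4')]
print(len(g.nodes))       # 4 (original unchanged)
```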

If you can't describe the nodes you want to keep with a graph expression pattern, then you can do it with node IDs (see the FoldId transform for an example).
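A node-ID variant could look like the sketch below. It keeps an explicit set of IDs and, optionally, collapses everything else into one placeholder node, which is roughly the "ResNet + layers of interest" view asked for in the original question. The `SubgraphById` class, its `placeholder` argument, and the stand-in `Node`/`Graph` classes are all hypothetical; this is only loosely in the spirit of folding nodes by ID and is not HiddenLayer's FoldId implementation.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Node:
    id: str
    op: str


@dataclass
class Graph:
    # Hypothetical stand-in for HiddenLayer's Graph.
    nodes: dict = field(default_factory=dict)  # node ID -> Node
    edges: list = field(default_factory=list)  # (source_id, target_id) tuples


class SubgraphById:
    """Hypothetical transform: keep the listed node IDs and, if a placeholder
    node is given, collapse all other nodes into that single node."""

    def __init__(self, keep_ids, placeholder=None):
        self.keep_ids = set(keep_ids)
        self.placeholder = placeholder

    def apply(self, graph):
        graph = copy.deepcopy(graph)  # leave the caller's graph untouched
        dropped = set(graph.nodes) - self.keep_ids
        graph.nodes = {nid: n for nid, n in graph.nodes.items() if nid in self.keep_ids}
        edges = []
        for a, b in graph.edges:
            if a in self.keep_ids and b in self.keep_ids:
                edges.append((a, b))
            elif self.placeholder is not None and a in dropped and b in self.keep_ids:
                # Edge entering the kept region: reroute it from the placeholder.
                edges.append((self.placeholder.id, b))
            elif self.placeholder is not None and a in self.keep_ids and b in dropped:
                # Edge leaving the kept region: reroute it to the placeholder.
                edges.append((a, self.placeholder.id))
            # Edges between two dropped nodes disappear with them.
        if self.placeholder is not None and dropped:
            graph.nodes[self.placeholder.id] = self.placeholder
        graph.edges = list(dict.fromkeys(edges))  # drop duplicates, keep order
        return graph


g = Graph(
    nodes={"c1": Node("c1", "Conv"), "c2": Node("c2", "Conv"),
           "fc": Node("fc", "Linear"), "sm": Node("sm", "Softmax")},
    edges=[("c1", "c2"), ("c2", "fc"), ("fc", "sm")],
)
sub = SubgraphById({"fc", "sm"}, placeholder=Node("encoder", "ResNet")).apply(g)
print(sorted(sub.nodes))  # ['encoder', 'fc', 'sm']
print(sub.edges)          # [('encoder', 'fc'), ('fc', 'sm')]
```

Deduplicating the rerouted edges matters when several encoder nodes feed the same kept layer; without it the placeholder would be drawn with repeated arrows.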

noklam commented 5 years ago

Thanks! That's a very useful guide. I'll play with it and see if I can come up with something.