noklam opened this issue 5 years ago
This would be a useful feature. It's not available right now, but you can probably do it with a bit of additional code. I'll tag this issue as an enhancement so hopefully we add it later, or maybe someone adds it and submits a pull request.
HiddenLayer stores the neural network as a graph in the Graph class: the nodes are stored in a dict and the edges in a list. You can easily manipulate that graph any way you like. For example, you can delete all the nodes except the ones you want to print.
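To make the "delete everything except what you want to print" idea concrete, here is a minimal sketch. It does not import HiddenLayer. Instead it defines hypothetical stand-in Node and Graph classes that only mirror the shape described above (nodes in a dict keyed by ID, edges in a list). For simplicity the edges are plain (src_id, dst_id) pairs; the real Graph class may store more per edge. The helper keep_only is also hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a HiddenLayer node: an ID plus an op name."""
    id: str
    op: str


@dataclass
class Graph:
    """Hypothetical stand-in: nodes in a dict keyed by ID,
    edges in a list of (src_id, dst_id) pairs."""
    nodes: dict = field(default_factory=dict)
    edges: list = field(default_factory=list)

    def add_node(self, node):
        self.nodes[node.id] = node

    def add_edge(self, src, dst):
        self.edges.append((src.id, dst.id))


def keep_only(graph, keep_ids):
    """Return a new Graph holding only the nodes in keep_ids and the edges
    whose endpoints were both kept. The input graph is not modified."""
    keep_ids = set(keep_ids)
    sub = Graph()
    for node_id, node in graph.nodes.items():
        if node_id in keep_ids:
            sub.nodes[node_id] = node
    # Drop any edge that touches a removed node, so nothing dangles.
    sub.edges = [(a, b) for (a, b) in graph.edges
                 if a in keep_ids and b in keep_ids]
    return sub


# Usage: a tiny Conv -> Relu -> Linear chain, keeping only the last two nodes.
g = Graph()
conv, relu, fc = Node("n1", "Conv"), Node("n2", "Relu"), Node("n3", "Linear")
for n in (conv, relu, fc):
    g.add_node(n)
g.add_edge(conv, relu)
g.add_edge(relu, fc)

sub = keep_only(g, {"n2", "n3"})
print(sorted(sub.nodes))  # ['n2', 'n3']
print(sub.edges)          # [('n2', 'n3')]
```

The key detail is filtering the edge list together with the node dict; removing nodes alone would leave edges pointing at IDs that no longer exist.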
Take a look at the Prune transform in transforms.py. It's a simple transform (~15 lines of code) that removes all nodes that match a pattern. You need to do the opposite: delete all nodes that do not match the pattern. Let's call it the Subgraph transform. If you write that, then you can use it as follows:
graph = hl.build_graph(pytorch_graph, inputs)
sub_graph = Subgraph(pattern_of_nodes_to_keep).apply(graph)
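A Subgraph transform along those lines could look roughly like the sketch below. It is a standalone illustration under clear assumptions: Node and Graph are hypothetical stand-ins (nodes in a dict, edges as (src_id, dst_id) pairs), and, to keep it short, a plain set of op names replaces HiddenLayer's graph-expression patterns, which this sketch does not implement. Only the apply(graph) method name follows the usage shown above.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Node:
    id: str
    op: str


@dataclass
class Graph:
    nodes: dict = field(default_factory=dict)  # id -> Node
    edges: list = field(default_factory=list)  # (src_id, dst_id) pairs


class Subgraph:
    """Hypothetical inverse of Prune: keep nodes whose op is in ops_to_keep
    and drop everything else, including edges left dangling.
    A set of op names stands in for a graph-expression pattern."""

    def __init__(self, ops_to_keep):
        self.ops_to_keep = set(ops_to_keep)

    def apply(self, graph):
        # Work on a copy so the original graph is left untouched.
        graph = copy.deepcopy(graph)
        drop = [nid for nid, node in graph.nodes.items()
                if node.op not in self.ops_to_keep]
        for nid in drop:
            del graph.nodes[nid]
        graph.edges = [(a, b) for (a, b) in graph.edges
                       if a in graph.nodes and b in graph.nodes]
        return graph


# Usage: keep only the classifier-style ops at the end of a small chain.
g = Graph()
for nid, op in [("n1", "Conv"), ("n2", "Relu"), ("n3", "Linear"), ("n4", "Softmax")]:
    g.nodes[nid] = Node(nid, op)
g.edges = [("n1", "n2"), ("n2", "n3"), ("n3", "n4")]

head = Subgraph({"Linear", "Softmax"}).apply(g)
print(sorted(head.nodes))  # ['n3', 'n4']
print(head.edges)          # [('n3', 'n4')]
```

A real implementation would swap the op-name check for HiddenLayer's own pattern search, the same mechanism Prune uses to find the nodes it removes.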
If you can't describe the nodes you want to keep with a graph expression pattern, then you can do it with node IDs (see the FoldId transform for an example).
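Selecting by node ID instead of by pattern might be sketched as follows. Everything here is hypothetical: the SubgraphById class, the stand-in Node/Graph classes, and the readable IDs such as "head/fc1". Real HiddenLayer node IDs depend on the framework and build options, so the regex would have to be written against whatever IDs your graph actually contains.

```python
import copy
import re
from dataclasses import dataclass, field


@dataclass
class Node:
    id: str
    op: str


@dataclass
class Graph:
    nodes: dict = field(default_factory=dict)  # id -> Node
    edges: list = field(default_factory=list)  # (src_id, dst_id) pairs


class SubgraphById:
    """Hypothetical ID-based variant: keep nodes whose ID fully matches
    id_regex, for cases where op types alone can't pick out the nodes."""

    def __init__(self, id_regex):
        self.id_regex = re.compile(id_regex)

    def apply(self, graph):
        graph = copy.deepcopy(graph)
        graph.nodes = {nid: node for nid, node in graph.nodes.items()
                       if self.id_regex.fullmatch(nid)}
        # Keep only edges whose endpoints both survived the filter.
        graph.edges = [(a, b) for (a, b) in graph.edges
                       if a in graph.nodes and b in graph.nodes]
        return graph


# Usage with illustrative IDs: keep everything under the "head/" prefix.
g = Graph()
for nid, op in [("encoder/conv1", "Conv"), ("encoder/relu1", "Relu"),
                ("head/fc1", "Linear"), ("head/softmax", "Softmax")]:
    g.nodes[nid] = Node(nid, op)
g.edges = [("encoder/conv1", "encoder/relu1"), ("encoder/relu1", "head/fc1"),
           ("head/fc1", "head/softmax")]

head = SubgraphById(r"head/.*").apply(g)
print(sorted(head.nodes))  # ['head/fc1', 'head/softmax']
print(head.edges)          # [('head/fc1', 'head/softmax')]
```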
Thx! That's a very useful guide. I will play with it and see if I can come up with something.
The graph is useful for visualizing a model, but when the model is big, printing the entire model is not very helpful. Say I already have the structure of the model (an OrderedDict or groups of layers): is it possible to print only part of the model graph?
I am using PyTorch with a pre-trained ResNet, and I am only interested in the layers after the ResNet encoder. Ideally, I would like to print something like a single ResNet box (maybe just passed as a manual input argument) plus the details of the layer groups I'm interested in.
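The "single ResNet box plus a detailed head" view asked about here could be approximated by folding the encoder nodes into one summary node instead of deleting them, in the spirit of the FoldId transform mentioned above. This is a standalone sketch, not HiddenLayer's API: the collapse helper, the stand-in Node/Graph classes, and the "encoder/" and "head/" ID prefixes are all hypothetical, and real node IDs will differ.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Node:
    id: str
    op: str


@dataclass
class Graph:
    nodes: dict = field(default_factory=dict)  # id -> Node
    edges: list = field(default_factory=list)  # (src_id, dst_id) pairs


def collapse(graph, prefix, summary_id, summary_op):
    """Hypothetical helper: replace every node whose ID starts with prefix
    by one summary node, keeping all other nodes in full detail."""
    graph = copy.deepcopy(graph)
    folded = {nid for nid in graph.nodes if nid.startswith(prefix)}
    if not folded:
        return graph
    for nid in folded:
        del graph.nodes[nid]
    graph.nodes[summary_id] = Node(summary_id, summary_op)

    new_edges = []
    for a, b in graph.edges:
        a = summary_id if a in folded else a
        b = summary_id if b in folded else b
        # Edges inside the folded block become self-loops; drop them,
        # and skip duplicate edges created by the rewiring.
        if a != b and (a, b) not in new_edges:
            new_edges.append((a, b))
    graph.edges = new_edges
    return graph


# Usage: fold a toy three-node "encoder" into one "ResNet" box.
g = Graph()
for nid, op in [("encoder/conv1", "Conv"), ("encoder/relu1", "Relu"),
                ("encoder/conv2", "Conv"), ("head/fc1", "Linear"),
                ("head/softmax", "Softmax")]:
    g.nodes[nid] = Node(nid, op)
g.edges = [("encoder/conv1", "encoder/relu1"), ("encoder/relu1", "encoder/conv2"),
           ("encoder/conv2", "head/fc1"), ("head/fc1", "head/softmax")]

view = collapse(g, "encoder/", "resnet", "ResNet")
print(sorted(view.nodes))  # ['head/fc1', 'head/softmax', 'resnet']
print(view.edges)          # [('resnet', 'head/fc1'), ('head/fc1', 'head/softmax')]
```

Folding rather than deleting keeps the connection from the encoder into the head visible, which is usually what you want when the point is to show where the interesting layers attach.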