mert-kurttutan / torchview

torchview: visualize pytorch models
https://torchview.dev
MIT License
Possibility to add submodule names to the graph, along with the names of their inputs and outputs #97

Open ammoramine opened 1 year ago

ammoramine commented 1 year ago

Hello, first of all, thanks for starting this very nice project.

Is your feature request related to a problem? Please describe.

Everything is in the title: the feature is the possibility of showing submodule names on the graph. This makes the graph easier to read, and is particularly useful when the graph has multiple layers of the same kind at a given level. Additionally, the names of the inputs and outputs would be shown next to their shapes.

Describe the solution you'd like

In terms of API, draw_graph would take an additional argument (for example showNames). It could be a boolean, which shows all names when True and no names when False, or a dictionary, which enables showing the names of specific layers only, selected by type, name, and location within the graph.
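A minimal sketch of how such an argument could be interpreted. Everything here is hypothetical: `show_names` (a snake_case spelling of the proposed showNames), the helper `should_show_name`, and the dictionary keys `types`, `names`, and `max_depth` are illustrative assumptions, not part of the torchview API.

```python
from typing import Union


def should_show_name(
    show_names: Union[bool, dict],
    name: str,
    type_name: str,
    depth: int,
) -> bool:
    """Decide whether a node's name is drawn (hypothetical helper)."""
    # Boolean form: True shows every name, False shows none.
    if isinstance(show_names, bool):
        return show_names

    # Dictionary form: each key that is present is a filter the node must pass.
    # The keys "types", "names" and "max_depth" are assumptions for illustration.
    types = show_names.get("types")
    names = show_names.get("names")
    max_depth = show_names.get("max_depth")

    if types is not None and type_name not in types:
        return False
    if names is not None and name not in names:
        return False
    if max_depth is not None and depth > max_depth:
        return False
    return True


# Example selector: show names only for Conv1D layers at depth 3 or shallower.
# The module names and depths below are made-up values for illustration.
selector = {"types": {"Conv1D"}, "max_depth": 3}
print(should_show_name(True, "h.0.attn.c_attn", "Conv1D", 3))      # True
print(should_show_name(selector, "h.0.attn.c_attn", "Conv1D", 3))  # True
print(should_show_name(selector, "h.0.ln_1", "LayerNorm", 2))      # False
```

Treating every dictionary key as an optional filter keeps the boolean and dictionary forms consistent: an empty dictionary behaves like True, and each added key only narrows the selection.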

Describe alternatives you've considered

No alternative that I am aware of, unless I misunderstood the current API.

Screenshots / Text

Here is example code to visualize GPT2 using your framework:

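# In a notebook, install the dependency first: !pip install transformers
from transformers import GPT2Tokenizer, GPT2Model
import torchview

# Load the pretrained GPT2 tokenizer and model.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Tokenize a sample sentence; the result holds input_ids and attention_mask.
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)

# Draw the rolled graph at depth 1 and render it to a file.
res = torchview.draw_graph(model, input_data=encoded_input, roll=True, depth=1)
res.visual_graph.render("model_depth_1")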

We can't know that the input-tensor in the top left of the graph is "attention_mask" except by guessing; the same goes for the token embedding and positional embedding on the right. Adding the names of submodules, along with the names of inputs and outputs, would make the graph easier to read.

The inspect module from Python's standard library could help for this purpose.
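As a rough illustration of that idea, the sketch below uses inspect.signature to recover the parameter name for each input passed to a forward method. `TinyModel` and `name_inputs` are hypothetical stand-ins written for this example, not torchview code, and plain lists are used in place of tensors so that the snippet runs without PyTorch.

```python
import inspect


class TinyModel:
    """Stand-in for an nn.Module; only the forward signature matters here."""

    def forward(self, input_ids=None, attention_mask=None, position_ids=None):
        return input_ids


def name_inputs(model, *args, **kwargs):
    """Map each input passed to model.forward to its parameter name."""
    signature = inspect.signature(model.forward)
    # bind_partial matches positional and keyword arguments to parameter
    # names without requiring every parameter to be supplied.
    bound = signature.bind_partial(*args, **kwargs)
    return dict(bound.arguments)


# The first argument is passed positionally, the second by keyword.
labels = name_inputs(TinyModel(), [101, 102], attention_mask=[1, 1])
print(list(labels))  # ['input_ids', 'attention_mask']
```

This only covers naming the top-level inputs. How such names would be attached to nodes in the drawn graph, and how output names would be obtained, is left open here.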