szagoruyko / pytorchviz

A small package to create visualizations of PyTorch execution graphs
MIT License
3.18k stars 279 forks source link

add color themes - drafted #84

Open johndpope opened 6 months ago

johndpope commented 6 months ago

Firstly, this is a great library. I added it to an ML library to illustrate the attention mechanism and was blown away: https://github.com/xmu-xiaoma666/External-Attention-pytorch/issues/115

The gray colors look a bit tired. I pasted the dot module into ChatGPT and asked it to add color theming from a third-party library. There are four or five off-the-shelf libraries that can handle this; palettable by @jiffyclub looked fine: https://github.com/jiffyclub/palettable

ChatGPT produced the following code to add palette support:

It would add about 300 KB of dependencies, but that is a small price to pay for clarity.


from collections import namedtuple
from distutils.version import LooseVersion
from graphviz import Digraph
import torch
from torch.autograd import Variable
import warnings
import palettable

# Module-level color palette: Palettable's 7-color qualitative ColorBrewer
# "Set1" scheme. ``mpl_colors`` gives matplotlib-style RGB tuples; make_dot
# multiplies each channel by 255 to build "#rrggbb" strings, so channels are
# presumably floats in [0, 1] (TODO confirm against Palettable's docs).
# make_dot indexes entries 0-5, so a replacement scheme needs >= 6 colors.
palette = palettable.colorbrewer.qualitative.Set1_7.mpl_colors

def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
    """Produce a Graphviz representation of a PyTorch autograd graph, colored with Palettable.

    Draws every autograd function (``grad_fn``) reachable from ``var``, the
    output tensor(s) themselves, the leaf tensors gradients accumulate into
    (via ``AccumulateGrad`` nodes), the base tensor of any view and, optionally,
    what each function saved for backward.

    Colors (from the module-level ``palette``): output tensors -> 'tensor',
    autograd functions -> 'operation', saved tensors -> 'saved_tensor',
    leaf tensors -> 'param_tensor', bases of views -> 'base_tensor'; the
    dotted base -> view edge is drawn in 'view_tensor'.

    Args:
        var: output tensor, or a tuple/list of output tensors.
        params: optional dict mapping names to tensors (e.g.
            ``dict(model.named_parameters())``); tensors found in it are
            labelled with their name.
        show_attrs: if True, list each function's non-tensor saved
            attributes (its ``_saved_*`` attributes) under its name.
        show_saved: if True, draw the tensors each function saved for
            backward, linked to it by an undirected edge.
        max_attr_chars: attribute values shown by ``show_attrs`` that are
            longer than this are truncated with "..." (minimum 3).

    Returns:
        graphviz.Digraph: the graph, sized by ``resize_graph``.

    Raises:
        AssertionError: if a value in ``params`` is not a tensor.
    """
    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}
    else:
        param_map = {}

    node_attr = dict(style='filled', shape='box', align='left', fontsize='10', ranksep='0.1', height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # Autograd nodes and tensors already drawn. Holding the objects themselves
    # (not their ids) keeps them alive, so an id cannot be reused mid-walk.
    seen = set()

    saved_prefix = '_saved_'
    max_attr_chars = max(max_attr_chars, 3)  # always leave room for "..."

    def color_to_hex(color):
        return '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))

    # Map each kind of graph element onto one palette entry.
    colors = {
        'tensor': color_to_hex(palette[0]),
        'operation': color_to_hex(palette[1]),
        'saved_tensor': color_to_hex(palette[2]),
        'param_tensor': color_to_hex(palette[3]),
        'base_tensor': color_to_hex(palette[4]),
        'view_tensor': color_to_hex(palette[5])
    }

    def get_var_name(tensor, name=None):
        """Label a tensor node with its name (or its ``params`` key) and size."""
        if not name:
            name = param_map.get(id(tensor), '')
        return f'{name}\n {tensor.size()}'

    def truncate(text):
        """Shorten ``text`` to ``max_attr_chars`` characters, ending in "..."."""
        if len(text) <= max_attr_chars:
            return text
        return text[:max_attr_chars - 3] + '...'

    def saved_items(fn):
        """Yield ``(name, value)`` pairs for what ``fn`` saved for backward.

        Built-in autograd nodes expose these as ``_saved_<name>`` attributes;
        nodes of custom ``torch.autograd.Function``s expose ``saved_tensors``
        (yielded with an empty name). Values already freed by a backward pass
        raise RuntimeError on access and are skipped.
        """
        for attr in dir(fn):
            if not attr.startswith(saved_prefix):
                continue
            try:
                value = getattr(fn, attr)
            except RuntimeError:
                continue
            yield attr[len(saved_prefix):], value
        try:
            custom = fn.saved_tensors
        except (AttributeError, RuntimeError):
            custom = ()
        for tensor in custom:
            yield '', tensor

    def add_saved_tensor(fn, tensor, name):
        """Draw a saved tensor once, but link it to every function that saved it."""
        if tensor not in seen:
            seen.add(tensor)
            dot.node(str(id(tensor)), get_var_name(tensor, name), fillcolor=colors['saved_tensor'])
        dot.edge(str(id(tensor)), str(id(fn)), dir="none")

    def function_label(fn_name, attrs):
        """Return ``fn_name``, followed by an aligned ``key: value`` table of ``attrs``."""
        if not attrs:
            return fn_name
        rows = [(key, truncate(str(value))) for key, value in attrs.items()]
        key_width = max(len(key) for key, _ in rows)
        value_width = max(len(value) for _, value in rows)
        rule = '-' * max(key_width + value_width + 2, len(fn_name))
        table = '\n'.join(f'{key:<{key_width}}: {value:>{value_width}}' for key, value in rows)
        return f'{fn_name}\n{rule}\n{table}'

    def add_nodes(root):
        """Draw ``root`` and every autograd function reachable from it."""
        # Iterative walk: a recursive one overflows Python's recursion limit
        # on deep graphs.
        stack = [root]
        while stack:
            fn = stack.pop()
            if fn in seen:
                continue
            seen.add(fn)

            attrs = {}
            if show_saved or show_attrs:
                for name, value in saved_items(fn):
                    if torch.is_tensor(value):
                        if show_saved:
                            add_saved_tensor(fn, value, name)
                    elif isinstance(value, tuple) and any(torch.is_tensor(v) for v in value):
                        if show_saved:
                            for i, item in enumerate(value):
                                if torch.is_tensor(item):
                                    add_saved_tensor(fn, item, f'{name}[{i}]')
                    else:
                        attrs[name] = value

            fn_name = type(fn).__name__
            label = function_label(fn_name, attrs) if show_attrs else fn_name
            dot.node(str(id(fn)), label, fillcolor=colors['operation'])

            # AccumulateGrad nodes hold the leaf tensor (e.g. a parameter)
            # that gradients flow into; drawing it is what makes the names
            # passed via ``params`` show up.
            if hasattr(fn, 'variable'):
                leaf = fn.variable
                if leaf not in seen:
                    seen.add(leaf)
                    dot.node(str(id(leaf)), get_var_name(leaf), fillcolor=colors['param_tensor'])
                dot.edge(str(id(leaf)), str(id(fn)))

            for next_fn, _ in getattr(fn, 'next_functions', ()):
                if next_fn is not None:
                    dot.edge(str(id(next_fn)), str(id(fn)))
                    stack.append(next_fn)

    def add_base_tensor(tensor, role='tensor'):
        """Draw a tensor, its autograd history and, for a view, its base."""
        if tensor in seen:
            return
        seen.add(tensor)

        dot.node(str(id(tensor)), get_var_name(tensor), fillcolor=colors[role])

        if tensor.grad_fn is not None:
            add_nodes(tensor.grad_fn)
            dot.edge(str(id(tensor.grad_fn)), str(id(tensor)))

        if tensor._is_view():
            base = tensor._base
            add_base_tensor(base, role='base_tensor')
            # Edges are colored with ``color``; ``fillcolor`` does not color the line.
            dot.edge(str(id(base)), str(id(tensor)), style="dotted", color=colors['view_tensor'])

    # handle multiple outputs
    outputs = var if isinstance(var, (tuple, list)) else (var,)
    for output in outputs:
        add_base_tensor(output)

    resize_graph(dot)

    return dot

def resize_graph(dot, size_per_element=0.15, min_size=12):
    """Grow a Graphviz graph's square ``size`` attribute with its content.

    Every entry in ``dot.body`` contributes ``size_per_element`` to the side
    length, which never drops below ``min_size``. The result is written back
    as ``dot.graph_attr['size'] = "<side>,<side>"``.
    """
    side = max(min_size, len(dot.body) * size_per_element)
    dot.graph_attr.update(size=f"{side},{side}")

UPDATE: it may be better just to use a Google-style theme; we don't need countless styles. Just borrow Google's blue (#4285F4) and green (#34A853) plus a sans-serif font, and hey presto (example SVG below).


<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg width="901pt" height="1998pt" viewBox="0.00 0.00 901.00 1998.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1994)">
<title>%3</title>
<polygon fill="white" stroke="transparent" points="-4,4 -4,-1994 897,-1994 897,4 -4,4"/>

<!-- Example of a node in Google's blue -->
<g id="node1" class="node">
<title>137319056510800</title>
<polygon fill="#4285F4" stroke="black" points="249,-31 148,-31 148,0 249,0 249,-31"/>
<text text-anchor="middle" x="198.5" y="-7" font-family="Arial, Helvetica, sans-serif" font-size="10.00" fill="white"> (50, 64, 512)</text>
</g>

<!-- Example of an edge with Google's style -->
<g id="edge80" class="edge">
<title>137319055447216&#45;&gt;137319056510800</title>
<path fill="none" stroke="#34A853" d="M150.92,-72.73C158.26,-64.06 169.79,-50.43 179.66,-38.76"/>
<polygon fill="#34A853" stroke="#34A853" points="182.37,-40.98 186.16,-31.08 177.03,-36.46 182.37,-40.98"/>
</g>

<!-- Add more nodes and edges here with similar styles -->
</g>
</svg>

Screenshot from 2024-03-27 14-35-37

leo-ware commented 5 days ago

@johndpope I forked the repo and republished it as torchviz2. Would you like to open a pull request on the new repo?

https://github.com/leo-ware/torchviz2

johndpope commented 5 days ago

Thanks for the update. Feel free to cherry-pick whatever you like.