The default grey colours look a bit tired. I pasted the `make_dot` code into ChatGPT and asked it to add colour theming from a third-party library — there are four or five off-the-shelf libraries that can handle this.
Palettable by @jiffyclub looked fine: https://github.com/jiffyclub/palettable
ChatGPT produced the following code to add that support.
It would add roughly 300 KB, but that is a small price to pay for clarity.
from collections import namedtuple
from distutils.version import LooseVersion
from graphviz import Digraph
import torch
from torch.autograd import Variable
import warnings
import palettable
# Use a color palette from Palettable: the 7-colour ColorBrewer qualitative
# "Set1" scheme. `mpl_colors` yields (r, g, b) triples scaled to 0-1, which
# make_dot's color_to_hex multiplies by 255 to build "#rrggbb" strings;
# make_dot reads entries 0-5.
palette = palettable.colorbrewer.qualitative.Set1_7.mpl_colors
def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Node fill colours are taken from the module-level Palettable ``palette``.

    Args:
        var: output tensor (or tuple of output tensors) whose ``grad_fn``
            graph is drawn.
        params: optional mapping ``{name: tensor}`` used to label leaf
            tensors, e.g. ``dict(model.named_parameters())``.
        show_attrs: if True, list each backward node's ``_saved_*``
            attributes under its type name.
        show_saved: if True, draw tensors saved for backward as extra nodes
            linked (undirected) to the backward node that saved them.
        max_attr_chars: attribute values longer than this (minimum 3) are
            truncated with "..." when ``show_attrs`` is True.

    Returns:
        The ``graphviz.Digraph``, already resized by ``resize_graph``.
    """
    saved_prefix = '_saved_'

    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}
    else:
        param_map = {}

    node_attr = dict(style='filled', shape='box', align='left', fontsize='10',
                     ranksep='0.1', height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # grad_fn nodes and tensors that have already been drawn.
    seen = set()

    def color_to_hex(color):
        # Palettable's mpl_colors are (r, g, b) floats in [0, 1].
        return '#%02x%02x%02x' % (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))

    # Define color scheme
    colors = {
        'tensor': color_to_hex(palette[0]),
        'operation': color_to_hex(palette[1]),
        'saved_tensor': color_to_hex(palette[2]),
        'param_tensor': color_to_hex(palette[3]),
        'base_tensor': color_to_hex(palette[4]),
        'view_tensor': color_to_hex(palette[5]),
    }

    def get_var_name(var, name=None):
        # Label: explicit name, else the name from `params`, else blank; then the size.
        if not name:
            name = param_map.get(id(var), '')
        return f'{name}\n {var.size()}'

    def saved_attrs(fn):
        """Yield (name, value) for each ``_saved_<name>`` attribute of ``fn``."""
        for attr in dir(fn):
            if not attr.startswith(saved_prefix):
                continue
            try:
                val = getattr(fn, attr)
            except RuntimeError:
                # Saved values are released after backward(); skip instead of crashing.
                continue
            yield attr[len(saved_prefix):], val

    def get_fn_name(fn):
        """Node label: the backward node's type name, plus its saved attributes if requested."""
        name = type(fn).__name__
        if not show_attrs:
            return name
        attrs = {}
        for attr, val in saved_attrs(fn):
            if torch.is_tensor(val):
                attrs[attr] = '[saved tensor]'
            elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val):
                attrs[attr] = '[saved tensors]'
            else:
                attrs[attr] = str(val)
        if not attrs:
            return name
        col1 = max(len(k) for k in attrs)
        col2 = min(max(len(v) for v in attrs.values()), max(max_attr_chars, 3))

        def truncate(s):
            return s[:col2 - 3] + '...' if len(s) > col2 else s

        rows = '\n'.join(f'{k:<{col1}}: {truncate(v):>{col2}}' for k, v in attrs.items())
        sep = '-' * max(col1 + col2 + 2, len(name))
        return f'{name}\n{sep}\n{rows}'

    def saved_tensors_of(fn):
        """Collect (tensor, label) pairs that ``fn`` saved for backward."""
        found = []
        # Custom autograd.Function nodes expose `saved_tensors` directly. This is
        # best-effort: the attribute is absent on built-in nodes and raises once freed.
        try:
            found.extend((t, None) for t in fn.saved_tensors if torch.is_tensor(t))
        except (AttributeError, RuntimeError):
            pass
        # Built-in backward nodes expose them as `_saved_<name>` attributes.
        for attr, val in saved_attrs(fn):
            if torch.is_tensor(val):
                found.append((val, attr))
            elif isinstance(val, tuple):
                found.extend((t, f'{attr}[{i}]') for i, t in enumerate(val) if torch.is_tensor(t))
        return found

    def add_nodes(fn):
        if fn in seen:
            return
        seen.add(fn)

        # Add nodes for saved tensors
        if show_saved:
            linked = set()
            for t, label in saved_tensors_of(fn):
                if id(t) in linked:
                    continue  # same tensor saved under several names: one edge is enough
                linked.add(id(t))
                if t not in seen:
                    seen.add(t)
                    dot.node(str(id(t)), get_var_name(t, label), fillcolor=colors['saved_tensor'])
                # Link even when another backward node already drew this tensor.
                dot.edge(str(id(t)), str(id(fn)), dir="none")

        # Add the node for this grad_fn
        dot.node(str(id(fn)), get_fn_name(fn), fillcolor=colors['operation'])

        # Gradient-accumulator nodes hold their leaf tensor (e.g. a parameter) in
        # `.variable`; drawing it is what makes `params` labels and the param colour show up.
        if hasattr(fn, 'variable'):
            leaf = fn.variable
            if leaf not in seen:
                seen.add(leaf)
                dot.node(str(id(leaf)), get_var_name(leaf), fillcolor=colors['param_tensor'])
            dot.edge(str(id(leaf)), str(id(fn)))

        # Recurse for next functions; entries are (grad_fn or None, index) pairs.
        for next_fn, _ in getattr(fn, 'next_functions', ()):
            if next_fn is not None:
                dot.edge(str(id(next_fn)), str(id(fn)))
                add_nodes(next_fn)

    def add_base_tensor(var, color=None):
        if var in seen:
            return
        seen.add(var)
        dot.node(str(id(var)), get_var_name(var), fillcolor=color or colors['tensor'])
        if var.grad_fn is not None:
            add_nodes(var.grad_fn)
            dot.edge(str(id(var.grad_fn)), str(id(var)))
        if var._is_view():
            # The tensor a view was taken from gets the base colour; the dotted
            # view edge uses `color` (not `fillcolor`) so its line is actually tinted.
            base_var = var._base
            add_base_tensor(base_var, colors['base_tensor'])
            dot.edge(str(id(base_var)), str(id(var)), style="dotted", color=colors['view_tensor'])

    # handle multiple outputs
    outputs = var if isinstance(var, tuple) else (var,)
    for v in outputs:
        add_base_tensor(v)

    resize_graph(dot)
    return dot
def resize_graph(dot, size_per_element=0.15, min_size=12):
    """Scale the graph's ``size`` attribute to the amount of content in it.

    The side length is ``size_per_element`` per line in ``dot.body``, never
    smaller than ``min_size``; the same value is used for width and height
    and written into ``dot.graph_attr['size']`` as ``"side,side"``.
    """
    side = max(min_size, size_per_element * len(dot.body))
    dot.graph_attr.update(size=f"{side},{side}")
UPDATE
It may be better just to use a Google-style theme — we don't need ten million styles.
Just borrow Google's blue (#4285F4) and green (#34A853) plus a sans-serif font (Arial/Helvetica),
and hey presto:
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg width="901pt" height="1998pt" viewBox="0.00 0.00 901.00 1998.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1994)">
<title>%3</title>
<polygon fill="white" stroke="transparent" points="-4,4 -4,-1994 897,-1994 897,4 -4,4"/>
<!-- Example of a node in Google's blue -->
<g id="node1" class="node">
<title>137319056510800</title>
<polygon fill="#4285F4" stroke="black" points="249,-31 148,-31 148,0 249,0 249,-31"/>
<text text-anchor="middle" x="198.5" y="-7" font-family="Arial, Helvetica, sans-serif" font-size="10.00" fill="white"> (50, 64, 512)</text>
</g>
<!-- Example of an edge with Google's style -->
<g id="edge80" class="edge">
<title>137319055447216->137319056510800</title>
<path fill="none" stroke="#34A853" d="M150.92,-72.73C158.26,-64.06 169.79,-50.43 179.66,-38.76"/>
<polygon fill="#34A853" stroke="#34A853" points="182.37,-40.98 186.16,-31.08 177.03,-36.46 182.37,-40.98"/>
</g>
<!-- Add more nodes and edges here with similar styles -->
</g>
</svg>
Firstly — this is a great library. I added it to an ML library to illustrate the attention mechanism and was blown away: https://github.com/xmu-xiaoma666/External-Attention-pytorch/issues/115
The default grey colours look a bit tired. I pasted the `make_dot` code into ChatGPT and asked it to add colour theming from a third-party library — there are four or five off-the-shelf libraries that can handle this. Palettable by @jiffyclub looked fine: https://github.com/jiffyclub/palettable
ChatGPT produced the following code to add that support.
It would add roughly 300 KB, but that is a small price to pay for clarity.
UPDATE: it may be better just to use a Google-style theme — we don't need ten million styles. Just borrow Google's blue and green plus a sans-serif font, and hey presto.