Open ghost opened 9 years ago
Update:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import networkx as nx
from localint import LocalInteraction
def _default_node_colors():
    """Return matplotlib's default color cycle as a list of colors."""
    # 'axes.color_cycle' was deprecated and later removed from matplotlib;
    # 'axes.prop_cycle' is its replacement. Fall back to the legacy key only
    # on installations that predate the property cycler.
    if 'axes.prop_cycle' in mpl.rcParams:
        return list(mpl.rcParams['axes.prop_cycle'].by_key().get('color', []))
    return mpl.rcParams['axes.color_cycle']


def _resolve_layout(G, pos):
    """Return a node-position dict for `G` from a dict or a layout name."""
    if isinstance(pos, dict):
        return pos
    # Look the layout function up explicitly instead of wrapping everything
    # in a bare `except`, so that KeyboardInterrupt and genuine failures
    # inside the layout algorithm are not masked as "bad pos".
    layout_func = getattr(nx, '{0}_layout'.format(pos), None)
    if layout_func is None:
        raise ValueError(
            "pos must be a dictionary of node-position pairs, or one of " +
            "{'circular', 'random', 'shell', 'spring', 'spectral'}")
    return layout_func(G)


def animation(li, init_actions=None, pos='circular', node_size=200,
              node_colors=None, linewidth=2, interval=200, figsize=(16,10)):
    """
    Show an animation of the actions taken on a local interaction network.

    The network is built from `li.adj_matrix`; on every frame each node is
    painted with the color assigned to its entry in `li.current_actions`,
    and then `li.play()` is called to advance the model. The function blocks
    in `plt.show()` and closes the figure afterwards.

    Parameters
    ----------
    li : object
        Model to animate. Must provide the attributes `num_actions`,
        `adj_matrix` and `current_actions`, and the methods
        `set_init_actions(init_actions)` and `play()`. (Presumably a
        `LocalInteraction` instance -- confirm against the `localint`
        module.)

    init_actions : optional(default=None)
        Passed unchanged to `li.set_init_actions` before the animation
        starts.

    pos : dict or str, optional(default='circular')
        Either a dictionary of node-position pairs, or the name of a
        networkx layout: `nx.<pos>_layout` is called on the graph.

    node_size : scalar, optional(default=200)
        Node size passed to `nx.draw_networkx_nodes`.

    node_colors : sequence, optional(default=None)
        One color per action, indexed by action number. If None, the
        default matplotlib color cycle is used.

    linewidth : scalar, optional(default=2)
        Width of the edges.

    interval : scalar, optional(default=200)
        Passed to `FuncAnimation` as the delay between frames.

    figsize : tuple, optional(default=(16,10))
        Size of the figure.

    Raises
    ------
    ValueError
        If fewer colors than `li.num_actions` are available, or if `pos`
        is neither a dict nor the name of a networkx layout function.

    """
    num_actions = li.num_actions

    if node_colors is None:
        node_colors = _default_node_colors()
    num_colors = len(node_colors)
    if num_colors < num_actions:
        raise ValueError('{0} colors required '.format(num_actions) +
                         '(only {0} provided)'.format(num_colors))

    G = nx.DiGraph(li.adj_matrix)
    pos = _resolve_layout(G, pos)

    # Node artists drawn for the previous frame; removed before each redraw
    # so they do not pile up on the axes for the lifetime of the animation.
    node_artists = []

    def get_fig(n):
        # `n` is the frame number supplied by FuncAnimation; it is unused.
        while node_artists:
            node_artists.pop().remove()
        for i in range(num_actions):
            nodelist = np.where(li.current_actions == i)[0].tolist()
            artist = nx.draw_networkx_nodes(G, pos, node_size=node_size,
                                            nodelist=nodelist,
                                            node_color=node_colors[i])
            # Some networkx versions return None for an empty nodelist.
            if artist is not None:
                node_artists.append(artist)
        li.play()
        return fig

    li.set_init_actions(init_actions)

    fig = plt.figure(figsize=figsize, facecolor='w')
    # Edges never change, so they are drawn once, outside the frame function.
    nx.draw_networkx_edges(G, pos, alpha=0.5, width=linewidth, arrows=False)

    # Keep a reference to the animation object: matplotlib stops the
    # animation if it is garbage collected before `plt.show()` returns.
    anim = FuncAnimation(fig, get_fig, interval=interval)

    plt.axis('off')
    plt.show()
    plt.close()