oyamad / game_theory_models

Python code for game theory modeling
BSD 3-Clause "New" or "Revised" License
31 stars · 9 forks · source link

Animation #18

Status: Open · opened by ghost 9 years ago

ghost commented 9 years ago
def animation(li, init_actions=None, pos='circular', node_size=1000, node_colors=None, linewidth=6, interval=1000):
    """Animate the evolution of players' actions in a local interaction game.

    Each frame colors every node by its current action, then advances the
    game one step with ``li.play()``.

    Parameters
    ----------
    li : object
        Local interaction game; must provide ``set_init_actions``,
        ``num_actions``, ``adj_matrix``, ``current_actions`` and ``play``.
    init_actions : optional
        Passed unchanged to ``li.set_init_actions``.
    pos : dict or str, optional
        Node positions as a dict, or one of 'random', 'shell', 'spring',
        'spectral'. Any other value falls back to the circular layout.
    node_size : int, optional
        Size of the drawn nodes.
    node_colors : sequence, optional
        One color per action. Defaults to white, blue, red (truncated to
        the number of actions).
    linewidth : float, optional
        Width of the drawn edges.
    interval : int, optional
        Delay between frames, in milliseconds.

    Raises
    ------
    ValueError
        If the game has more than 3 actions, or if ``node_colors`` does not
        contain exactly one color per action.
    """
    li.set_init_actions(init_actions)
    num_actions = li.num_actions

    if num_actions > 3:
        raise ValueError('The number of actions must be at most 3.')

    G = nx.Graph(li.adj_matrix)

    if isinstance(pos, dict):
        pass  # use the caller-supplied positions as-is
    elif pos == 'random':
        pos = nx.random_layout(G)
    elif pos == 'shell':
        pos = nx.shell_layout(G)
    elif pos == 'spring':
        pos = nx.spring_layout(G)
    elif pos == 'spectral':
        pos = nx.spectral_layout(G)
    else:
        pos = nx.circular_layout(G)

    if node_colors is None:
        node_colors = ['w', 'b', 'r'][:num_actions]
    elif num_actions != len(node_colors):
        raise ValueError('The number of actions must correspond to the number of colors.')

    fig = plt.figure()
    # The graph structure never changes, so draw the edges once up front
    # instead of re-adding a new edge collection on every frame.
    nx.draw_networkx_edges(G, pos, alpha=0.5, width=linewidth)

    # Node collections drawn for the previous frame; removed before redrawing
    # so the axes do not accumulate artists as the animation runs.
    node_artists = []

    def get_fig(frame):
        while node_artists:
            node_artists.pop().remove()
        for action in range(num_actions):
            nodelist = np.where(li.current_actions == action)[0].tolist()
            # Color each partition with the color of *its* action.
            artist = nx.draw_networkx_nodes(G, pos, node_size=node_size,
                                            nodelist=nodelist,
                                            node_color=node_colors[action])
            # draw_networkx_nodes may return None for an empty nodelist
            # (older networkx versions) -- nothing to remove later then.
            if artist is not None:
                node_artists.append(artist)
        li.play()
        return fig

    # Keep a reference to the animation object while the window is shown;
    # otherwise it could be garbage-collected and stop updating.
    anim = FuncAnimation(fig, get_fig, interval=interval)
    plt.show()
oyamad commented 9 years ago

Update:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import networkx as nx
from localint import LocalInteraction

def animation(li, init_actions=None, pos='circular', node_size=200,
              node_colors=None, linewidth=2, interval=200, figsize=(16,10)):
    """Animate the evolution of players' actions in a local interaction game.

    The directed graph given by ``li.adj_matrix`` is drawn once (edges only).
    Each frame recolors the nodes by their current action and then advances
    the game one step with ``li.play()``.

    Parameters
    ----------
    li : LocalInteraction
        Game instance; must provide ``num_actions``, ``adj_matrix``,
        ``current_actions``, ``set_init_actions`` and ``play``.
    init_actions : optional
        Passed unchanged to ``li.set_init_actions``.
    pos : dict or str, optional
        Dict of node-position pairs, or a layout name ``X`` such that
        ``networkx.X_layout`` exists (e.g. 'circular', 'random', 'shell',
        'spring', 'spectral').
    node_size : int, optional
        Size of the drawn nodes.
    node_colors : sequence, optional
        Colors indexed by action. Defaults to matplotlib's color cycle.
    linewidth : float, optional
        Width of the drawn edges.
    interval : int, optional
        Delay between frames, in milliseconds.
    figsize : tuple, optional
        Passed to ``plt.figure``.

    Raises
    ------
    ValueError
        If fewer colors than actions are available, or if ``pos`` is neither
        a dict nor the name of a networkx layout function.
    """
    num_actions = li.num_actions

    if node_colors is None:
        # 'axes.color_cycle' was deprecated in matplotlib 1.5 in favor of
        # 'axes.prop_cycle' and later removed; prefer the new key and fall
        # back to the old one only for matplotlib versions that lack it.
        try:
            node_colors = [props['color']
                           for props in mpl.rcParams['axes.prop_cycle']]
        except KeyError:
            node_colors = mpl.rcParams['axes.color_cycle']
    num_colors = len(node_colors)
    if num_colors < num_actions:
        raise ValueError('{0} colors required '.format(num_actions) +
                         '(only {0} provided)'.format(num_colors))

    G = nx.DiGraph(li.adj_matrix)

    if not isinstance(pos, dict):
        # Only the name lookup is guarded: errors raised *inside* a valid
        # layout function (e.g. a missing optional dependency) must surface
        # as-is rather than being masked as an invalid `pos`.
        try:
            layout_func = getattr(nx, '{0}_layout'.format(pos))
        except AttributeError:
            raise ValueError(
                "pos must be a dictionary of node-position pairs, or one of " +
                "{'circular', 'random', 'shell', 'spring', 'spectral'}")
        pos = layout_func(G)

    # Node collections drawn for the previous frame; removed before redrawing
    # so the axes do not accumulate artists (and memory) every frame.
    node_artists = []

    def get_fig(n):
        while node_artists:
            node_artists.pop().remove()
        for i in range(num_actions):
            nodelist = np.where(li.current_actions == i)[0].tolist()
            artist = nx.draw_networkx_nodes(G, pos, node_size=node_size,
                                            nodelist=nodelist,
                                            node_color=node_colors[i])
            # draw_networkx_nodes may return None for an empty nodelist
            # (older networkx versions) -- nothing to remove later then.
            if artist is not None:
                node_artists.append(artist)
        li.play()
        return fig

    li.set_init_actions(init_actions)

    fig = plt.figure(figsize=figsize, facecolor='w')
    # Edges are static, so they are drawn once outside the frame callback.
    nx.draw_networkx_edges(G, pos, alpha=0.5, width=linewidth, arrows=False)
    # Keep a reference to the animation while the window is shown; otherwise
    # it could be garbage-collected and stop updating.
    anim = FuncAnimation(fig, get_fig, interval=interval)
    plt.axis('off')
    plt.show()
    plt.close()