Cogito2012 / DEAR

[ICCV 2021 Oral] Deep Evidential Action Recognition
Apache License 2.0
118 stars 18 forks source link

About the visualization #15

Closed HaoliangZhou closed 2 months ago

HaoliangZhou commented 2 months ago

Very good paper! May I ask how to produce the visualization in Fig. 4?

Cogito2012 commented 2 months ago

@HaoliangZhou Thanks for your interest! You may refer to the code below, which is modified from visualizing-dirichlet-distributions:

import numpy as np
from math import gamma
from operator import mul
import matplotlib.pyplot as plt
import matplotlib.tri as tri

class Dirichlet(object):
    '''Dirichlet distribution parameterised by concentration vector ``alpha``.

    Only the probability density is implemented (see ``pdf``).
    '''

    def __init__(self, alpha):
        '''Store ``alpha`` and precompute the normalising coefficient.

        Args:
            alpha: sequence of concentration parameters; every entry must be
                strictly positive.

        Raises:
            ValueError: if any entry of ``alpha`` is not strictly positive.
            OverflowError: only if the normalising coefficient itself is too
                large for a float (far beyond what ``math.gamma`` allowed).
        '''
        # Local import: the file header only brings ``gamma`` into scope.
        from math import exp, lgamma

        self._alpha = np.array(alpha)
        if np.any(self._alpha <= 0):
            raise ValueError(
                'Dirichlet concentration parameters must be > 0, got %r' % (alpha,))
        # Gamma(sum(alpha)) / prod(Gamma(a)) evaluated in log space.
        # math.gamma overflows once its argument exceeds ~171, so the direct
        # ratio failed for moderately large total concentration even though
        # the ratio itself is usually representable.
        log_coef = lgamma(float(np.sum(self._alpha))) - sum(
            lgamma(float(a)) for a in self._alpha)
        self._coef = exp(log_coef)

    def pdf(self, x):
        '''Return the density value at ``x``.

        Args:
            x: one coordinate per entry of ``alpha`` (e.g. the barycentric
                coordinates produced by ``xy2bc``).
        '''
        return self._coef * np.multiply.reduce(
            [xx ** (aa - 1) for (xx, aa) in zip(x, self._alpha)])

def xy2bc(xy, tol=1.e-4, corners=None):
    '''Convert 2D Cartesian coordinates to barycentric coordinates.

    Coordinate ``i`` is the area of the sub-triangle formed by ``xy`` and the
    two corners other than corner ``i``, divided by the area of the whole
    triangle. Areas are absolute values, so the coordinates sum to 1 only for
    points inside (or on) the triangle.

    Args:
        xy: the ``(x, y)`` point to convert.
        tol: results are clipped to ``[tol, 1 - tol]``, keeping points off
            the exact simplex boundary.
        corners: 3x2 array-like of triangle vertices. Defaults to the
            equilateral triangle (0, 0), (1, 0), (0.5, sqrt(0.75)).

    Returns:
        NumPy array of three clipped barycentric coordinates.
    '''
    if corners is None:
        corners = ((0.0, 0.0), (1.0, 0.0), (0.5, 0.75 ** 0.5))
    corners = np.asarray(corners, dtype=float)
    point = np.asarray(xy, dtype=float)

    def _area(p, q, r):
        # Half the absolute 2-D cross product of (q - p, r - p). Written out
        # explicitly because np.cross on 2-D vectors is deprecated (NumPy 2.0).
        (ux, uy), (vx, vy) = q - p, r - p
        return 0.5 * abs(ux * vy - uy * vx)

    total = _area(corners[0], corners[1], corners[2])
    sub_areas = [_area(point, corners[(i + 1) % 3], corners[(i + 2) % 3])
                 for i in range(3)]
    coords = np.array(sub_areas) / total
    return np.clip(coords, tol, 1.0 - tol)

def compute_pdf_contours(dist, subdiv=8, triangle=None):
    '''Evaluate ``dist.pdf`` on a refined triangular mesh of the simplex.

    Args:
        dist: object with a ``pdf(coords)`` method taking barycentric
            coordinates (e.g. ``Dirichlet``).
        subdiv: number of recursive refinement levels passed to
            ``UniformTriRefiner.refine_triangulation``.
        triangle: ``matplotlib.tri.Triangulation`` of the outer triangle.
            Defaults to the equilateral triangle with corners (0, 0), (1, 0)
            and (0.5, sqrt(0.75)). It is built here rather than read from a
            global that only exists when the file is run as a script.

    Returns:
        ``(trimesh, pvals)``: the refined triangulation and a list with one
        pdf value per mesh vertex.
    '''
    if triangle is None:
        corners = np.array([[0, 0], [1, 0], [0.5, 0.75 ** 0.5]])
        triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]
    return trimesh, pvals

def plot_simplex(trimesh, pvals, title=None, alpha=None, nlevels=200, fontsize=15, **kwargs):
    '''Draw filled pdf contours over the simplex in a new 4x4-inch figure.

    Args:
        trimesh: triangulation returned by ``compute_pdf_contours``.
        pvals: pdf values at the vertices of ``trimesh``.
        title: optional caption drawn below the simplex. When given, a second
            line showing ``u = len(alpha) / sum(alpha)`` is drawn under it.
        alpha: concentration parameters used for the ``u`` caption; defaults
            to ``(1, 1, 1)``.
        nlevels: number of contour levels for ``tricontourf``.
        fontsize: font size of both caption lines.
        **kwargs: forwarded to ``plt.tricontourf``.

    Returns:
        The created ``matplotlib`` figure. The original returned ``None``, so
        callers that ignore the result are unaffected.
    '''
    if alpha is None:  # avoid a shared mutable default argument
        alpha = (1, 1, 1)
    fig = plt.figure(figsize=(4, 4))
    # NOTE: this changes the font family process-wide, not just for this figure.
    plt.rcParams["font.family"] = "Arial"
    plt.tricontourf(trimesh, pvals, nlevels, cmap='jet', **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75 ** 0.5)
    plt.axis('off')
    if title is not None:
        axes_coords = plt.gca().transAxes
        plt.text(0.5, -0.1, title, horizontalalignment='center',
                 verticalalignment='center', transform=axes_coords, fontsize=fontsize)
        # Number of classes divided by total concentration sum(alpha).
        u = len(alpha) / np.sum(alpha)
        plt.text(0.5, -0.3, "$u=%.1f$" % (u), horizontalalignment='center',
                 verticalalignment='center', transform=axes_coords, fontsize=fontsize)
    plt.tight_layout()
    return fig

if __name__ == '__main__':
    # Equilateral 2-simplex in the plane. These stay module globals because
    # xy2bc / compute_pdf_contours may look them up by name.
    corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
    AREA = 0.5 * 1 * 0.75**0.5
    triangle = tri.Triangulation(corners[:, 0], corners[:, 1])

    # For each corner of the triangle, the pair of other corners
    pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]

    def tri_area(xy, pair):
        '''Area of the triangle formed by point ``xy`` and the two points in ``pair``.'''
        (ux, uy), (vx, vy) = pair - xy
        # Explicit 2-D cross product: np.cross on 2-D vectors is deprecated (NumPy 2.0).
        return 0.5 * abs(ux * vy - uy * vx)

    fontsize = 35

    trimesh, pvals = compute_pdf_contours(Dirichlet([12, 2, 4]))
    plot_simplex(trimesh, pvals)
    plt.savefig('simplex.png')
    plt.close()

    # (file tag, concentration parameters, caption)
    cases = [
        ('AC', [10, 1.2, 1.2], "$\\alpha=[10, 1.2, 1.2]$"),    # case 1: Accurate and Certain
        ('AU', [1.8, 1.2, 1.2], "$\\alpha=[1.8, 1.2, 1.2]$"),  # case 2: Accurate and Uncertain
        ('IC', [10, 10, 10], "$\\alpha=[10, 10, 10]$"),        # case 3: Inaccurate and Certain
        ('IU', [1.2, 1.2, 1.2], "$\\alpha=[1.2, 1.2, 1.2]$"),  # case 4: Inaccurate and Uncertain
    ]
    for tag, alpha, title in cases:
        trimesh, pvals = compute_pdf_contours(Dirichlet(alpha))
        plot_simplex(trimesh, pvals, title, alpha, fontsize=fontsize)
        for ext in ('png', 'pdf'):
            plt.savefig('simplex_%s.%s' % (tag, ext))
        # Release each figure once saved instead of leaving them all open.
        plt.close()
HaoliangZhou commented 2 months ago

Thanks a lot!