HaoliangZhou closed this issue 2 months ago.
@HaoliangZhou Thanks for your interest! You may refer to the code below, which is modified from visualizing-dirichlet-distributions:
import numpy as np
from math import gamma
from operator import mul
import matplotlib.pyplot as plt
import matplotlib.tri as tri
class Dirichlet(object):
    '''Dirichlet distribution over the probability simplex.

    Parameters
    ----------
    alpha : sequence of float
        Concentration parameters, one per component (expected to be > 0).
    '''

    def __init__(self, alpha):
        # math.gamma overflows (OverflowError) once sum(alpha) exceeds ~171,
        # so the normalising constant is built in log space with lgamma.
        from math import lgamma

        self._alpha = np.array(alpha)
        self._log_coef = (lgamma(float(np.sum(self._alpha)))
                          - sum(lgamma(float(a)) for a in self._alpha))
        # Kept for backward compatibility with code that reads ``_coef``.
        self._coef = np.exp(self._log_coef)

    def pdf(self, x):
        '''Returns pdf value for `x`.

        ``x`` is a point with one coordinate per component of ``alpha``.
        The density is evaluated as
        ``exp(log_coef + sum((alpha - 1) * log(x)))`` so that large alphas
        neither overflow the constant nor underflow the raw product.
        '''
        x = np.asarray(x, dtype=float)
        with np.errstate(divide='ignore', invalid='ignore'):
            log_terms = (self._alpha - 1) * np.log(x)
        # x ** 0 == 1 (including 0 ** 0) for components with alpha == 1; the
        # log form would otherwise produce 0 * -inf = nan at x == 0.
        log_terms = np.where(self._alpha == 1, 0.0, log_terms)
        return np.exp(self._log_coef + np.sum(log_terms))
def xy2bc(xy, tol=1.e-4, corners=None):
    '''Converts 2D Cartesian coordinates to barycentric.

    Parameters
    ----------
    xy : pair of float
        Cartesian point ``(x, y)``.
    tol : float
        Each coordinate is clipped to ``[tol, 1 - tol]`` so the pdf is never
        evaluated at an exact 0 or 1.
    corners : (3, 2) array-like, optional
        Triangle vertices. Defaults to the unit-side equilateral triangle
        ``(0, 0), (1, 0), (0.5, sqrt(0.75))``.

    Returns
    -------
    numpy.ndarray of shape (3,)
        Coordinate ``i`` is the area of the sub-triangle formed by ``xy`` and
        the two corners other than ``i``, divided by the full triangle area.
        Areas are unsigned, so these are true barycentric coordinates only
        for points inside the triangle.

    Raises
    ------
    ValueError
        If ``corners`` describe a degenerate (zero-area) triangle.
    '''
    if corners is None:
        corners = [[0, 0], [1, 0], [0.5, 0.75 ** 0.5]]
    corners = np.asarray(corners, dtype=float)
    point = np.asarray(xy, dtype=float)

    def _area(p, q, r):
        # Unsigned area from the z-component of the 2-D cross product; avoids
        # np.cross on 2-D vectors, which NumPy 2.0 deprecates.
        (ux, uy), (vx, vy) = q - p, r - p
        return 0.5 * abs(ux * vy - uy * vx)

    total = _area(corners[0], corners[1], corners[2])
    if total == 0:
        raise ValueError('corners must span a non-degenerate triangle')
    coords = np.array([
        _area(point, corners[(i + 1) % 3], corners[(i + 2) % 3])
        for i in range(3)
    ]) / total
    return np.clip(coords, tol, 1.0 - tol)
def compute_pdf_contours(dist, subdiv=8):
    '''Evaluate ``dist.pdf`` on a refined mesh of the equilateral simplex.

    Parameters
    ----------
    dist : object with a ``pdf(bc)`` method
        Called once per mesh vertex with that vertex's barycentric
        coordinates (as produced by ``xy2bc``).
    subdiv : int
        Number of recursive subdivisions passed to
        ``UniformTriRefiner.refine_triangulation``.

    Returns
    -------
    (trimesh, pvals)
        The refined ``matplotlib.tri.Triangulation`` and a list with one pdf
        value per mesh vertex, in vertex order.
    '''
    # Build the base triangle locally instead of relying on a global that is
    # only defined when this file runs as a script.
    corners = np.array([[0, 0], [1, 0], [0.5, 0.75 ** 0.5]])
    base = tri.Triangulation(corners[:, 0], corners[:, 1])
    refiner = tri.UniformTriRefiner(base)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)
    pvals = [dist.pdf(xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]
    return trimesh, pvals
def plot_simplex(trimesh, pvals, title=None, alpha=None, nlevels=200, fontsize=15, **kwargs):
    '''Draw filled pdf contours on the simplex in a new 4x4-inch figure.

    Parameters
    ----------
    trimesh, pvals
        Mesh and per-vertex values, e.g. from ``compute_pdf_contours``.
    title : str, optional
        Caption drawn below the simplex (axes coordinates y = -0.1).
    alpha : sequence of float, optional
        Concentration parameters used for the ``u`` caption; defaults to
        ``[1, 1, 1]``.
    nlevels : int
        Number of contour levels passed to ``plt.tricontourf``.
    fontsize : int
        Font size for both captions.
    **kwargs
        Forwarded to ``plt.tricontourf``. ``cmap`` defaults to ``'jet'`` but
        may be overridden here.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure (callers may ignore it).

    Notes
    -----
    Sets ``plt.rcParams["font.family"] = "Arial"`` globally as a side effect.
    '''
    # None sentinel instead of a shared mutable default list.
    if alpha is None:
        alpha = [1, 1, 1]
    # Let callers override the colormap without a duplicate-keyword TypeError.
    kwargs.setdefault('cmap', 'jet')

    fig = plt.figure(figsize=(4, 4))
    plt.rcParams["font.family"] = "Arial"
    plt.tricontourf(trimesh, pvals, nlevels, **kwargs)
    plt.axis('equal')
    plt.xlim(0, 1)
    plt.ylim(0, 0.75 ** 0.5)
    plt.axis('off')

    ax = plt.gca()
    if title is not None:
        plt.text(0.5, -0.1, title, horizontalalignment='center',
                 verticalalignment='center', transform=ax.transAxes,
                 fontsize=fontsize)
    # u = K / sum(alpha) with K = len(alpha); presumably the evidential
    # "uncertainty mass" shown in the paper's figure.
    u = len(alpha) / np.sum(alpha)
    plt.text(0.5, -0.3, f"$u={u:.1f}$", horizontalalignment='center',
             verticalalignment='center', transform=ax.transAxes,
             fontsize=fontsize)
    plt.tight_layout()
    return fig
if __name__ == '__main__':
    # Geometry of the unit-side equilateral simplex. These names are kept at
    # module level because helpers in this file may look them up globally.
    corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
    AREA = 0.5 * 1 * 0.75**0.5
    triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
    # For each corner of the triangle, the pair of other corners
    pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]

    def tri_area(xy, pair):
        '''Area of the triangle formed by point ``xy`` and a pair of points.'''
        # Explicit 2-D cross product: np.cross on 2-D vectors is deprecated
        # in NumPy 2.0.
        (ux, uy), (vx, vy) = pair - xy
        return 0.5 * abs(ux * vy - uy * vx)

    fontsize = 35

    trimesh, pvals = compute_pdf_contours(Dirichlet([12, 2, 4]))
    plot_simplex(trimesh, pvals)
    plt.savefig('simplex.png')
    plt.close()

    # (file tag, alpha, caption) for the four accuracy/certainty cases
    cases = [
        # case 1: Accurate and Certain (AC)
        ('AC', [10, 1.2, 1.2], "$\\alpha=[10, 1.2, 1.2]$"),
        # case 2: Accurate and Uncertain (AU)
        ('AU', [1.8, 1.2, 1.2], "$\\alpha=[1.8, 1.2, 1.2]$"),
        # case 3: Inaccurate and Certain (IC)
        ('IC', [10, 10, 10], "$\\alpha=[10, 10, 10]$"),
        # case 4: Inaccurate and Uncertain (IU)
        ('IU', [1.2, 1.2, 1.2], "$\\alpha=[1.2, 1.2, 1.2]$"),
    ]
    for tag, alpha, title in cases:
        trimesh, pvals = compute_pdf_contours(Dirichlet(alpha))
        plot_simplex(trimesh, pvals, title, alpha, fontsize=fontsize)
        for ext in ('png', 'pdf'):
            plt.savefig(f'simplex_{tag}.{ext}')
        # Free each figure once written instead of accumulating open figures.
        plt.close()
Thanks a lot!
Great paper! Could you please explain how to produce the visualization in Fig. 4?