パレート解を評価するのに、自己組織化マップ(SOM)という手法を使うことが多い。これは、多目的関数の結果を可視化するのに便利なのと、多数のパレート解をグルーピングするのに使えるためらしい。SOMのアルゴリズム自体は、いくつかバリエーションがあるようだが、基本のものはそんなに難しくないようだし、python bindingのものもいくつかあるようだ(e.g., SOMPY)。
で、関連論文なんかを読んでいるとSOMの結果を見せるのに、6角形のHexMapを使うことが多いようで、matplotlibで書けないか調べてみたら、やっぱりあった(笑)。
module - Hexagonal Self-Organizing map in Python - Stack Overflow
手元のCentOS6.6の標準のmatplotlibではエラーになるけど、CentOS7系とかの新しいものであれば大丈夫なはずです。あと、Hexの中心点の計算を行う関数があると良いかもしれないです。
import matplotlib.pyplot as plt
from matplotlib import *
from matplotlib.collections import RegularPolyCollection
from mpl_toolkits.axes_grid import make_axes_locatable
import math
import numpy as np
def plot_map(grid,
data,
w=1080,
dpi=72.,
title='SOM Hit map'):
"""
Args:
- grid: Grid dictionary (keys: hex_centers, width, height ),
- data: data array contaning the distances between each neuron
- w: width of the map in inches
- title: map title
Returns the Matplotlib SubAxis instance
"""
centersXY = grid['hex_centers']
width = grid['width']
height = grid['height']
# Size of figure in inches ( they are scaled as equal aspect ratio)
winch = (width * w / height ) / dpi
hinch = (height * w / width) / dpi
# make figure
fig = plt.figure(figsize=(winch, hinch), dpi=dpi, facecolor="white")
# add ax
ax = fig.add_subplot(111, aspect='equal')
# Get pixel size between to data points
Xpoints = centersXY[:, 0]
Ypoints = centersXY[:, 1]
ax.scatter(Xpoints, Ypoints, color="black", edgecolor="none", s=3, marker='o')
ax.axis([min(Xpoints)-1., max(Xpoints)+1.,
min(Ypoints)-1., max(Ypoints)+1.])
# make pixel data
xy_pixels = ax.transData.transform(np.vstack([Xpoints, Ypoints]).T)
xpix, ypix = xy_pixels.T
# In matplotlib, 0,0 is the lower left corner, whereas it's usually the
# upper right for most image software, so we'll flip the y-coords
ww, hh = fig.canvas.get_width_height()
ypix = hh - ypix
# discover radius and hexagon
apothem = 0.85 * (xpix[1]-xpix[0])/2.0 # 0.85 is spacing factor
area_inner_circle = math.pi*(apothem**2)
collection_bg = RegularPolyCollection(
numsides=6, # a hexagon
rotation=0, # rotation angle in radian
sizes=(area_inner_circle,),
facecolors = (1, 1, 1, 1), # white (not required but for safty)
edgecolors = (0, 0, 0, 1), # black
linewidths = (1,),
cmap = cm.rainbow,
offsets = centersXY,
transOffset = ax.transData,
array = data
)
ax.add_collection(collection_bg, autolim=True)
# make tics and titles (add here)
ax.axis('on')
ax.autoscale_view()
ax.set_title(title, y=-0.1, size=24)
for tick in ax.xaxis.get_major_ticks():
tick.label1On = False
tick.label2On = False
tick.tick1On = False
tick.tick2On = False
for tick in ax.yaxis.get_major_ticks():
tick.label1On = False
tick.label2On = False
tick.tick1On = False
tick.tick2On = False
#ax.scatter(Xpoints, Ypoints, color="black", edgecolor="none", s=3, marker='o')
# make color bar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.2)
cax.tick_params(labelsize=20)
cb = plt.colorbar(collection_bg, cax=cax)
cb.set_label(label='label', weight='bold', size=24)
return ax
###################################################
# make hex grid dictionary (center of hex, figure width, figure height)
grid = {'hex_centers': np.array([
[ 1.5 , 0.8660254 ],
[ 2.5 , 0.8660254 ],
[ 3.5 , 0.8660254 ],
[ 4.5 , 0.8660254 ],
[ 5.5 , 0.8660254 ],
[ 6.5 , 0.8660254 ],
[ 1. , 1.73205081],
[ 2. , 1.73205081],
[ 3. , 1.73205081],
[ 4. , 1.73205081],
[ 5. , 1.73205081],
[ 6. , 1.73205081],
[ 1.5 , 2.59807621],
[ 2.5 , 2.59807621],
[ 3.5 , 2.59807621],
[ 4.5 , 2.59807621],
[ 5.5 , 2.59807621],
[ 6.5 , 2.59807621],
[ 1. , 3.46410162],
[ 2. , 3.46410162],
[ 3. , 3.46410162],
[ 4. , 3.46410162],
[ 5. , 3.46410162],
[ 6. , 3.46410162],
[ 1.5 , 4.33012702],
[ 2.5 , 4.33012702],
[ 3.5 , 4.33012702],
[ 4.5 , 4.33012702],
[ 5.5 , 4.33012702],
[ 6.5 , 4.33012702],
[ 1. , 5.19615242],
[ 2. , 5.19615242],
[ 3. , 5.19615242],
[ 4. , 5.19615242],
[ 5. , 5.19615242],
[ 6. , 5.19615242]]),
'width': np.array([ 5.0 ]), # inch
'height': np.array([ 5.0 ])} # inch
# make data (data example : f(x,y) = x+y )
data = grid['hex_centers'][:,0]+grid["hex_centers"][:,1]
# make plot
ax = plot_map(grid,data)
# plot show
plt.show()