いわて駐在研究日誌2。

NEVER STAND BEHIND ME

形状最適化(ちょっと脱線)

 パレート解を評価するのに、自己組織化マップ(SOM)という手法を使うことが多い。これは、多目的関数の結果を可視化するのに便利なのと、多数のパレート解をグルーピングするのに使えるためらしい。SOMのアルゴリズム自体は、いくつかバリエーションがあるようだが、基本のものはそんなに難しくないようだし、python bindingのものもいくつかあるようだ(e.g., SOMPY)。

 で、関連論文なんかを読んでいるとSOMの結果を見せるのに、6角形のHexMapを使うことが多いようで、matplotlibで書けないか調べてみたら、やっぱりあった(笑)。


module - Hexagonal Self-Organizing map in Python - Stack Overflow

 

 手元のCentOS6.6の標準のmatplotlibではエラーになるけど、CentOS7系とかの新しいものであれば大丈夫なはずです。あと、Hexの中心点の計算を行う関数があると良いかもしれないです。

import matplotlib.pyplot as plt
from matplotlib import *
from matplotlib.collections import RegularPolyCollection
from mpl_toolkits.axes_grid import make_axes_locatable
import math
import numpy as np

def plot_map(grid,
             data,
             w=1080,
             dpi=72.,
             title='SOM Hit map'):
  """
  Args:
  - grid: Grid dictionary (keys: hex_centers, width, height ),
  - data: data array contaning the distances between each neuron
  - w: width of the map in inches
  - title: map title
  Returns the Matplotlib SubAxis instance
  """
  centersXY = grid['hex_centers']
  width = grid['width']
  height = grid['height']

  # Size of figure in inches ( they are scaled as equal aspect ratio)
  winch = (width * w / height ) / dpi
  hinch = (height * w / width)  / dpi

  # make figure
  fig = plt.figure(figsize=(winch, hinch), dpi=dpi, facecolor="white")

  # add ax
  ax = fig.add_subplot(111, aspect='equal')

  # Get pixel size between to data points
  Xpoints = centersXY[:, 0]
  Ypoints = centersXY[:, 1]
  ax.scatter(Xpoints, Ypoints, color="black", edgecolor="none", s=3, marker='o')
  ax.axis([min(Xpoints)-1., max(Xpoints)+1.,
           min(Ypoints)-1., max(Ypoints)+1.])
 
  # make pixel data
  xy_pixels = ax.transData.transform(np.vstack([Xpoints, Ypoints]).T)
  xpix, ypix = xy_pixels.T

  # In matplotlib, 0,0 is the lower left corner, whereas it's usually the
  # upper right for most image software, so we'll flip the y-coords
  ww, hh = fig.canvas.get_width_height()
  ypix = hh - ypix
 
  # discover radius and hexagon
  apothem = 0.85 * (xpix[1]-xpix[0])/2.0 # 0.85 is spacing factor
  area_inner_circle = math.pi*(apothem**2)

  collection_bg = RegularPolyCollection(
    numsides=6,  # a hexagon
    rotation=0,  # rotation angle in radian
    sizes=(area_inner_circle,),
      facecolors = (1, 1, 1, 1), # white (not required but for safty)
      edgecolors = (0, 0, 0, 1), # black
      linewidths = (1,),
      cmap = cm.rainbow,
      offsets = centersXY,
      transOffset = ax.transData,
      array = data
  )
  ax.add_collection(collection_bg, autolim=True)
 
  # make tics and titles (add here)
  ax.axis('on')
  ax.autoscale_view()
  ax.set_title(title, y=-0.1, size=24)
  for tick in ax.xaxis.get_major_ticks():
    tick.label1On = False
    tick.label2On = False
    tick.tick1On = False
    tick.tick2On = False
  for tick in ax.yaxis.get_major_ticks():
    tick.label1On = False
    tick.label2On = False
    tick.tick1On = False
    tick.tick2On = False
 
  #ax.scatter(Xpoints, Ypoints, color="black", edgecolor="none", s=3, marker='o')

  # make color bar
  divider = make_axes_locatable(ax)
  cax = divider.append_axes("right", size="5%", pad=0.2)
  cax.tick_params(labelsize=20)
  cb = plt.colorbar(collection_bg, cax=cax)
  cb.set_label(label='label', weight='bold', size=24)

  return ax

###################################################

# make hex grid dictionary (center of hex, figure width, figure height)
grid = {'hex_centers': np.array([
   [ 1.5       ,  0.8660254 ],
   [ 2.5       ,  0.8660254 ],
   [ 3.5       ,  0.8660254 ],
   [ 4.5       ,  0.8660254 ],
   [ 5.5       ,  0.8660254 ],
   [ 6.5       ,  0.8660254 ],
   [ 1.        ,  1.73205081],
   [ 2.        ,  1.73205081],
   [ 3.        ,  1.73205081],
   [ 4.        ,  1.73205081],
   [ 5.        ,  1.73205081],
   [ 6.        ,  1.73205081],
   [ 1.5       ,  2.59807621],
   [ 2.5       ,  2.59807621],
   [ 3.5       ,  2.59807621],
   [ 4.5       ,  2.59807621],
   [ 5.5       ,  2.59807621],
   [ 6.5       ,  2.59807621],
   [ 1.        ,  3.46410162],
   [ 2.        ,  3.46410162],
   [ 3.        ,  3.46410162],
   [ 4.        ,  3.46410162],
   [ 5.        ,  3.46410162],
   [ 6.        ,  3.46410162],
   [ 1.5       ,  4.33012702],
   [ 2.5       ,  4.33012702],
   [ 3.5       ,  4.33012702],
   [ 4.5       ,  4.33012702],
   [ 5.5       ,  4.33012702],
   [ 6.5       ,  4.33012702],
   [ 1.        ,  5.19615242],
   [ 2.        ,  5.19615242],
   [ 3.        ,  5.19615242],
   [ 4.        ,  5.19615242],
   [ 5.        ,  5.19615242],
   [ 6.        ,  5.19615242]]),
'width': np.array([ 5.0 ]), # inch
'height': np.array([ 5.0 ])} # inch

# make data (data example :  f(x,y) = x+y )
data = grid['hex_centers'][:,0]+grid["hex_centers"][:,1]

# make plot
ax = plot_map(grid,data)

# plot show
plt.show()