· 

numpyでsoftargmaxを実装してみた

(著)山たー

使うか分からないが、忘れないようにメモ。ニューラルネットワークを組む時にargmaxを使いたい時があるが(例えばヒートマップにおける最大値の座標を出すとき)、そのままだと誤差伝搬できない。そこで、条件を緩和したsoftargmaxを用いる場合がある。これはsoftmax関数を少しいじったものである。

 

数式での定義

まず、通常のsoftmaxは $$ \text{softmax}(x)=\frac{\exp(x_i)}{\sum_j \exp(x_j)} $$ となる。実際に使う場合はオーバーフローを防ぐため、 $$ \text{softmax}(x)=\frac{\exp(x_i-\max(x))}{\sum_j \exp(x_j-\max(x))} $$ という式を用いる。ただし、$\max(x)$は配列$x$における最大値である。

一方で、softargmaxは $$ \text{softargmax}(x)=\sum_i \frac{\exp(\beta x_i-\max(\beta x))}{\sum_j \exp(\beta x_j-\max(\beta x))}\cdot i $$ となる。$\beta$は大きな数(100以上あれば十分か?)である。

2次元の場合での実装

2次元の場合で実装してみた。

import numpy as np

def softmax(x, beta=10):
    c = np.max(beta*x)
    ex = np.exp(beta*x - c)
    sum_ex = np.sum(ex)
    return ex / sum_ex
    
def softargmax_2D(x):
    assert x.ndim == 2, "x dim must be 2"
    
    s = softmax(x)
    x = np.arange(x.shape[1])
    y = np.arange(x.shape[0])    
    xx, yy = np.meshgrid(x, y)
    xmax = np.sum(s*xx)
    ymax = np.sum(s*yy)
    return ymax, xmax

# toy array
arr = np.random.randint(0, 100, (4, 4))
print("array\n", arr)

# argmax
argmax = np.argwhere(arr.max() == arr)
print("\nargmax:", argmax[0,0], argmax[0,1])

# softargmax
ymax, xmax = softargmax_2D(arr)
print("softargmax: (float)", ymax, xmax)
print("softargmax: (round)  ",
      int(round(ymax)), int(round(xmax)))

 

結果は次のようになった。

array
 [[13 73 41 82]
 [44 69 29 26]
 [37  3 14 43]
 [38 50  9 45]]

argmax: 0 3
softargmax: (float) 3.4811068399043105e-57 3.0
softargmax: (round)   0 3

ヒートマップに対しての使用例

ヒートマップを適当に生成して、argmaxとsoftargmaxを比較してみた。

ヒートマップ

argmax (最大点を白丸で表示)

softmax

※ディスプレイについてる白いゴミにしか見えないが、最大点を中心としたデルタ関数っぽいピークが存在している。

 

softargmax (最大点を白丸で表示)

import cv2
import numpy as np

def GaussianMask(sizex,sizey, sigma=10, center=None,fix=1):
    """
    sizex  : mask width
    sizey  : mask height
    sigma  : gaussian Sd
    center : gaussian mean
    fix    : gaussian max
    return gaussian mask
    """
    x = np.arange(0, sizex, 1, float)
    y = np.arange(0, sizey, 1, float)
    x, y = np.meshgrid(x,y)
    
    if center is None:
        x0 = sizex // 2
        y0 = sizey // 2
    else:
        if np.isnan(center[0])==False and np.isnan(center[1])==False:            
            x0 = center[0]
            y0 = center[1]        
        else:
            return np.zeros((sizey,sizex))

    return fix*np.exp(-4*np.log(2) * ((x-x0)**2 + (y-y0)**2) / sigma**2)

def Pos2Densemap(fix_arr, width, height):   
    heatmap = np.zeros((H,W), np.float32)
    for n_subject in range(fix_arr.shape[0]):
        heatmap += GaussianMask(W, H, 33, (fix_arr[n_subject,0],fix_arr[n_subject,1]),
                                fix_arr[n_subject,2])

    # Normalization
    heatmap = heatmap/np.amax(heatmap)
    heatmap = heatmap*255
    heatmap = heatmap.astype("uint8")
    #heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    return heatmap

def softmax(x, beta=100):
    x = x/np.amax(x)
    c = np.max(beta*x)
    ex = np.exp(beta*x - c)
    sum_ex = np.sum(ex)
    return ex / sum_ex
    
def softargmax_2D(x):
    assert x.ndim == 2, "x dim must be 2"
    
    s = softmax(x)
    x = np.arange(x.shape[1])
    y = np.arange(x.shape[0])    
    xx, yy = np.meshgrid(x, y)
    xmax = np.sum(s*xx)
    ymax = np.sum(s*yy)
    return ymax, xmax

if __name__ == '__main__':
    # Generate toy data
    num_subjects = 40
    H, W = 256, 256
    
    fix_arr = np.random.randn(num_subjects,3)
    fix_arr -= fix_arr.min()
    fix_arr /= fix_arr.max()
    fix_arr[:,0] *= W
    fix_arr[:,1] *= H
    
    # Create heatmap
    heatmap = Pos2Densemap(fix_arr, W, H)
    heatmap_c = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    cv2.imshow('heatmap', heatmap_c)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    cv2.imwrite("heatmap.png", heatmap_c)
    
    # argmax
    argmax = np.argwhere(heatmap.max() == heatmap)
    heatmap_argmax = cv2.circle(heatmap_c, (argmax[0,1], argmax[0,0]), 2, (255, 255, 255), -1)
    cv2.imshow('argmax', heatmap_argmax)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    cv2.imwrite("argmax.png", heatmap_argmax)
    
    # softmax
    heatmap_softmax = softmax(heatmap)*255
    heatmap_softmax = heatmap_softmax.astype("uint8")
    heatmap_softmax = cv2.applyColorMap(heatmap_softmax, cv2.COLORMAP_JET)
    cv2.imshow('softmax', heatmap_softmax)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    cv2.imwrite("softmax.png", heatmap_softmax)
    
    # softargmax
    ymax, xmax = softargmax_2D(heatmap)
    heatmap_softargmax = cv2.circle(heatmap_c, (int(np.round(xmax)), int(np.round(ymax))), 2, (255, 255, 255), -1)
    cv2.imshow('softargmax', heatmap_softargmax)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    cv2.imwrite("softargmax.png", heatmap_softargmax)
        

まあ、何とかなってる。

 

Chainerとかで使う場合にはFunctionを組み合わせるか、chainer.functionを新しく実装するかのいずれかが必要。