(著)山たー
使うか分からないが、忘れないようにメモ。ニューラルネットワークを組む時にargmaxを使いたい時があるが(例えばヒートマップにおける最大値の座標を出すとき)、そのままだと誤差伝搬できない。そこで、条件を緩和したsoftargmaxを用いる場合がある。これはsoftmax関数を少しいじったものである。
数式での定義
まず、通常のsoftmaxは $$ \text{softmax}(x)=\frac{\exp(x_i)}{\sum_j \exp(x_j)} $$ となる。実際に使う場合はオーバーフローを防ぐため、 $$ \text{softmax}(x)=\frac{\exp(x_i-\max(x))}{\sum_j \exp(x_j-\max(x))} $$
という式を用いる。ただし、$\max(x)$は配列$x$における最大値である。
一方で、softargmaxは $$ \text{softargmax}(x)=\sum_i \frac{\exp(\beta x_i-\max(\beta x))}{\sum_j \exp(\beta x_j-\max(\beta x))}\cdot i $$ となる。$\beta$は大きな数(100以上あれば十分か?)である。
一方で、softargmaxは $$ \text{softargmax}(x)=\sum_i \frac{\exp(\beta x_i-\max(\beta x))}{\sum_j \exp(\beta x_j-\max(\beta x))}\cdot i $$ となる。$\beta$は大きな数(100以上あれば十分か?)である。
2次元の場合での実装
2次元の場合で実装してみた。
import numpy as np
def softmax(x, beta=10):
c = np.max(beta*x)
ex = np.exp(beta*x - c)
sum_ex = np.sum(ex)
return ex / sum_ex
def softargmax_2D(x):
assert x.ndim == 2, "x dim must be 2"
s = softmax(x)
x = np.arange(x.shape[1])
y = np.arange(x.shape[0])
xx, yy = np.meshgrid(x, y)
xmax = np.sum(s*xx)
ymax = np.sum(s*yy)
return ymax, xmax
# toy array
arr = np.random.randint(0, 100, (4, 4))
print("array\n", arr)
# argmax
argmax = np.argwhere(arr.max() == arr)
print("\nargmax:", argmax[0,0], argmax[0,1])
# softargmax
ymax, xmax = softargmax_2D(arr)
print("softargmax: (float)", ymax, xmax)
print("softargmax: (round) ",
int(round(ymax)), int(round(xmax)))
結果は次のようになった。
array
[[13 73 41 82]
[44 69 29 26]
[37 3 14 43]
[38 50 9 45]]
argmax: 0 3
softargmax: (float) 3.4811068399043105e-57 3.0
softargmax: (round) 0 3ヒートマップに対しての使用例
ヒートマップを適当に生成して、argmaxとsoftargmaxを比較してみた。
ヒートマップ
argmax (最大点を白丸で表示)
softmax
※ディスプレイについてる白いゴミにしか見えないが、最大点を中心としたデルタ関数っぽいピークが存在している。
softargmax (最大点を白丸で表示)
import cv2
import numpy as np
def GaussianMask(sizex,sizey, sigma=10, center=None,fix=1):
"""
sizex : mask width
sizey : mask height
sigma : gaussian Sd
center : gaussian mean
fix : gaussian max
return gaussian mask
"""
x = np.arange(0, sizex, 1, float)
y = np.arange(0, sizey, 1, float)
x, y = np.meshgrid(x,y)
if center is None:
x0 = sizex // 2
y0 = sizey // 2
else:
if np.isnan(center[0])==False and np.isnan(center[1])==False:
x0 = center[0]
y0 = center[1]
else:
return np.zeros((sizey,sizex))
return fix*np.exp(-4*np.log(2) * ((x-x0)**2 + (y-y0)**2) / sigma**2)
def Pos2Densemap(fix_arr, width, height):
heatmap = np.zeros((H,W), np.float32)
for n_subject in range(fix_arr.shape[0]):
heatmap += GaussianMask(W, H, 33, (fix_arr[n_subject,0],fix_arr[n_subject,1]),
fix_arr[n_subject,2])
# Normalization
heatmap = heatmap/np.amax(heatmap)
heatmap = heatmap*255
heatmap = heatmap.astype("uint8")
#heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
return heatmap
def softmax(x, beta=100):
x = x/np.amax(x)
c = np.max(beta*x)
ex = np.exp(beta*x - c)
sum_ex = np.sum(ex)
return ex / sum_ex
def softargmax_2D(x):
assert x.ndim == 2, "x dim must be 2"
s = softmax(x)
x = np.arange(x.shape[1])
y = np.arange(x.shape[0])
xx, yy = np.meshgrid(x, y)
xmax = np.sum(s*xx)
ymax = np.sum(s*yy)
return ymax, xmax
if __name__ == '__main__':
# Generate toy data
num_subjects = 40
H, W = 256, 256
fix_arr = np.random.randn(num_subjects,3)
fix_arr -= fix_arr.min()
fix_arr /= fix_arr.max()
fix_arr[:,0] *= W
fix_arr[:,1] *= H
# Create heatmap
heatmap = Pos2Densemap(fix_arr, W, H)
heatmap_c = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
cv2.imshow('heatmap', heatmap_c)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.imwrite("heatmap.png", heatmap_c)
# argmax
argmax = np.argwhere(heatmap.max() == heatmap)
heatmap_argmax = cv2.circle(heatmap_c, (argmax[0,1], argmax[0,0]), 2, (255, 255, 255), -1)
cv2.imshow('argmax', heatmap_argmax)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.imwrite("argmax.png", heatmap_argmax)
# softmax
heatmap_softmax = softmax(heatmap)*255
heatmap_softmax = heatmap_softmax.astype("uint8")
heatmap_softmax = cv2.applyColorMap(heatmap_softmax, cv2.COLORMAP_JET)
cv2.imshow('softmax', heatmap_softmax)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.imwrite("softmax.png", heatmap_softmax)
# softargmax
ymax, xmax = softargmax_2D(heatmap)
heatmap_softargmax = cv2.circle(heatmap_c, (int(np.round(xmax)), int(np.round(ymax))), 2, (255, 255, 255), -1)
cv2.imshow('softargmax', heatmap_softargmax)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.imwrite("softargmax.png", heatmap_softargmax)
まあ、何とかなってる。
Chainerとかで使う場合にはFunctionを組み合わせるか、chainer.functionを新しく実装するかのいずれかが必要。
コメントをお書きください