(著)山たー
使うか分からないが、忘れないようにメモ。ニューラルネットワークを組む時にargmaxを使いたい時があるが(例えばヒートマップにおける最大値の座標を出すとき)、そのままだと誤差伝搬できない。そこで、条件を緩和したsoftargmaxを用いる場合がある。これはsoftmax関数を少しいじったものである。
数式での定義
まず、通常のsoftmaxは $$ \text{softmax}(x)=\frac{\exp(x_i)}{\sum_j \exp(x_j)} $$ となる。実際に使う場合はオーバーフローを防ぐため、 $$ \text{softmax}(x)=\frac{\exp(x_i-\max(x))}{\sum_j \exp(x_j-\max(x))} $$
という式を用いる。ただし、$\max(x)$は配列$x$における最大値である。
一方で、softargmaxは $$ \text{softargmax}(x)=\sum_i \frac{\exp(\beta x_i-\max(\beta x))}{\sum_j \exp(\beta x_j-\max(\beta x))}\cdot i $$ となる。$\beta$は大きな数(100以上あれば十分か?)である。
一方で、softargmaxは $$ \text{softargmax}(x)=\sum_i \frac{\exp(\beta x_i-\max(\beta x))}{\sum_j \exp(\beta x_j-\max(\beta x))}\cdot i $$ となる。$\beta$は大きな数(100以上あれば十分か?)である。
2次元の場合での実装
2次元の場合で実装してみた。
import numpy as np def softmax(x, beta=10): c = np.max(beta*x) ex = np.exp(beta*x - c) sum_ex = np.sum(ex) return ex / sum_ex def softargmax_2D(x): assert x.ndim == 2, "x dim must be 2" s = softmax(x) x = np.arange(x.shape[1]) y = np.arange(x.shape[0]) xx, yy = np.meshgrid(x, y) xmax = np.sum(s*xx) ymax = np.sum(s*yy) return ymax, xmax # toy array arr = np.random.randint(0, 100, (4, 4)) print("array\n", arr) # argmax argmax = np.argwhere(arr.max() == arr) print("\nargmax:", argmax[0,0], argmax[0,1]) # softargmax ymax, xmax = softargmax_2D(arr) print("softargmax: (float)", ymax, xmax) print("softargmax: (round) ", int(round(ymax)), int(round(xmax)))
結果は次のようになった。
array
[[13 73 41 82]
[44 69 29 26]
[37 3 14 43]
[38 50 9 45]]
argmax: 0 3
softargmax: (float) 3.4811068399043105e-57 3.0
softargmax: (round) 0 3
ヒートマップに対しての使用例
ヒートマップを適当に生成して、argmaxとsoftargmaxを比較してみた。
ヒートマップ
argmax (最大点を白丸で表示)
softmax
※ディスプレイについてる白いゴミにしか見えないが、最大点を中心としたデルタ関数っぽいピークが存在している。
softargmax (最大点を白丸で表示)
import cv2 import numpy as np def GaussianMask(sizex,sizey, sigma=10, center=None,fix=1): """ sizex : mask width sizey : mask height sigma : gaussian Sd center : gaussian mean fix : gaussian max return gaussian mask """ x = np.arange(0, sizex, 1, float) y = np.arange(0, sizey, 1, float) x, y = np.meshgrid(x,y) if center is None: x0 = sizex // 2 y0 = sizey // 2 else: if np.isnan(center[0])==False and np.isnan(center[1])==False: x0 = center[0] y0 = center[1] else: return np.zeros((sizey,sizex)) return fix*np.exp(-4*np.log(2) * ((x-x0)**2 + (y-y0)**2) / sigma**2) def Pos2Densemap(fix_arr, width, height): heatmap = np.zeros((H,W), np.float32) for n_subject in range(fix_arr.shape[0]): heatmap += GaussianMask(W, H, 33, (fix_arr[n_subject,0],fix_arr[n_subject,1]), fix_arr[n_subject,2]) # Normalization heatmap = heatmap/np.amax(heatmap) heatmap = heatmap*255 heatmap = heatmap.astype("uint8") #heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) return heatmap def softmax(x, beta=100): x = x/np.amax(x) c = np.max(beta*x) ex = np.exp(beta*x - c) sum_ex = np.sum(ex) return ex / sum_ex def softargmax_2D(x): assert x.ndim == 2, "x dim must be 2" s = softmax(x) x = np.arange(x.shape[1]) y = np.arange(x.shape[0]) xx, yy = np.meshgrid(x, y) xmax = np.sum(s*xx) ymax = np.sum(s*yy) return ymax, xmax if __name__ == '__main__': # Generate toy data num_subjects = 40 H, W = 256, 256 fix_arr = np.random.randn(num_subjects,3) fix_arr -= fix_arr.min() fix_arr /= fix_arr.max() fix_arr[:,0] *= W fix_arr[:,1] *= H # Create heatmap heatmap = Pos2Densemap(fix_arr, W, H) heatmap_c = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) cv2.imshow('heatmap', heatmap_c) cv2.waitKey(0) cv2.destroyAllWindows() cv2.imwrite("heatmap.png", heatmap_c) # argmax argmax = np.argwhere(heatmap.max() == heatmap) heatmap_argmax = cv2.circle(heatmap_c, (argmax[0,1], argmax[0,0]), 2, (255, 255, 255), -1) cv2.imshow('argmax', heatmap_argmax) cv2.waitKey(0) cv2.destroyAllWindows() cv2.imwrite("argmax.png", heatmap_argmax) # softmax heatmap_softmax = softmax(heatmap)*255 heatmap_softmax = heatmap_softmax.astype("uint8") heatmap_softmax = cv2.applyColorMap(heatmap_softmax, cv2.COLORMAP_JET) cv2.imshow('softmax', heatmap_softmax) cv2.waitKey(0) cv2.destroyAllWindows() cv2.imwrite("softmax.png", heatmap_softmax) # softargmax ymax, xmax = softargmax_2D(heatmap) heatmap_softargmax = cv2.circle(heatmap_c, (int(np.round(xmax)), int(np.round(ymax))), 2, (255, 255, 255), -1) cv2.imshow('softargmax', heatmap_softargmax) cv2.waitKey(0) cv2.destroyAllWindows() cv2.imwrite("softargmax.png", heatmap_softargmax)
まあ、何とかなってる。
Chainerとかで使う場合にはFunctionを組み合わせるか、chainer.functionを新しく実装するかのいずれかが必要。
コメントをお書きください