スパイクトリガー平均・共分散の計算法 (Python)

(著)山拓

CourseraComputational Neuroscienceの2週目の課題をベースにして、スパイクトリガー平均(Spike-triggered average; STA)とスパイクトリガー共分散(Spike-triggered covariance; STC)をPythonで実装します。Week2のノートを先に参照してください。

 

Spike-triggered analysisの目的

スパイクが刺激の中のどのような特徴量によって引き起こされるかを知るために、フィルタを推定することのが目的。

 

ハエのH1ニューロン

まず課題の説明。

 

ハエにはH1という水平方向の運動に対して反応するニューロンが視覚野に存在します。次図はハエの脳の構造です。画像はWikipediaより引用しました。

 

H1ニューロンが発火する刺激はなんであるかを調べるため、白黒の縞模様を左右に動かしながら固定したハエに見せ、H1ニューロンの電位変化を記録します。

実験の様子はYoutubeで見ることができます。

Spike-triggered average; STA

STAの計算

STAはスパイクが起こった時点から一定の時間前(temporal window)までの間の刺激ベクトルを集めて平均を取ったものである。式で書くと次のようになる。 $$ \text{STA} =\frac{1}{N_s} \sum_{n=1}^{N_s} \boldsymbol{s}(t_n) $$ $t_n$は計測時間内における$n$番目のスパイク、$\boldsymbol{s}(t_n)$は刺激のベクトル(テンソル)で(temporal window) × (sのベクトル次元)となっている。$N_s$はスパイクの総数。temporal windowを$t_w$とすると、 $$ \boldsymbol{s}(t_n)=[s(t_n), s(t_n-1),\cdots, s(t_n-t_w)] $$ となる。当然だが、$\boldsymbol{s}(t_n)$とSTAの次元は同じである。

Week2の課題の解答

まずはSTAを求める関数を定義する。

import numpy as np

def compute_sta(stim, rho, num_timesteps):
    """Compute the spike-triggered average from a stimulus and spike-train.
    
    Args:
        stim: stimulus time-series 
        rho:  spike-train time-series
        num_timesteps: how many timesteps to use in STA
        
    Returns:
        spike-triggered average for num_timesteps timesteps before spike"""
    
    sta = np.zeros((num_timesteps,))
    
    # This command finds the indices of all of the spikes that occur
    # after 300 ms into the recording.
    # 150タイムステップ以降のスパイクの記録
    spike_times = rho[num_timesteps:].nonzero()[0] + num_timesteps
    
    # Fill in this value. Note that you should not count spikes that occur
    # before 300 ms into the recording.

    num_spikes = np.count_nonzero(spike_times) #53583
    
    # Compute the spike-triggered average of the spikes found.
    # To do this, compute the average of all of the vectors
    # starting 300 ms (exclusive) before a spike and ending at the time of
    # the event (inclusive). Each of these vectors defines a list of
    # samples that is contained within a window of 300 ms before each
    # spike. The average of these vectors should be completed in an
    # element-wise manner.
    
    for t in spike_times:
        sta += stim[t-num_timesteps:t]
    
    sta /= num_spikes
    
    return sta

 

このcompute_sta.pyからcompute_staをloadし、matfileのデータを分析する。 

結果は次図のようになった(色々付け足している)。

#Quiz 2 code.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from compute_sta import compute_sta
import scipy.io

# Load mat data
data = scipy.io.loadmat('c1p8.mat')

""" When using pickle data file
import pickle
FILENAME = 'c1p8.pickle'

with open(FILENAME, 'rb') as f:
    data = pickle.load(f)
"""

# stim: Stimulus time-series data (600000)
stim = data['stim']
stim = np.reshape(stim, (-1)) # Need reshape when useing mat data

"""rho : Spike-train time-series (600000,)
1 : Spike 
0 : None
"""
rho = data['rho']
rho = np.reshape(rho, (-1)) # Need reshape when useing mat data

# 2msごとにスパイクの有無を記録 -> サンプリングレートは500Hz
sampling_period = 2 # in ms

# 300ms = 150timesteps
num_timesteps = 150

sta = compute_sta(stim, rho, num_timesteps)

time = (np.arange(-num_timesteps, 0) + 1) * sampling_period

#STA :Leaky integration 
plt.figure(figsize=(10,5))
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[2, 1]) 
plt.subplot(gs[:,0])
plt.plot(time, sta)
plt.xlabel('Time (ms)')
plt.ylabel('Stimulus')
plt.title('Spike-Triggered Average')

#Normalized stimulus
plt.subplot(gs[0,1])
N_samples_represent = 1000
stim_selection = stim[:N_samples_represent]
stim_selection_norm = stim_selection / stim_selection.max()
time_selection = np.arange(0, N_samples_represent) * (sampling_period / 1000.0)
plt.plot(time_selection, stim_selection_norm)
plt.xlabel('Time (s)')
plt.title('Normalized stimulus for the first {} samples'.format(N_samples_represent))

#Spike train
plt.subplot(gs[1,1])
rho_selection = rho[:N_samples_represent]
time_selection = np.arange(0, N_samples_represent) * (sampling_period / 1000.0)
plt.bar(time_selection, rho_selection, color='red', width=0.001)
plt.xlabel('Time (s)')
plt.title('Spike time-series for the first {} samples'.format(N_samples_represent))

plt.tight_layout()
plt.savefig("Spike-Triggered Average")
plt.show()

Spike-triggered covariance; STC

STAは1つのフィルタしか復元できないが、STCを用いれば複数のフィルタの復元ができる。

STCの計算

STC行列は次のように計算できる。 $$ \text{STC} = \frac{1}{N_s-1}\sum_{n=1}^{N_s} \left[\boldsymbol{s}(t_n)-\text{STA}\right]\left[\boldsymbol{s}(t_n)-\text{STA}\right]^T $$ STAは上で計算したベクトル(テンソル)である。

 

この後、STC行列の固有値を求める固有値問題を解く。分散共分散行列の固有値問題を解くのがPCAであり、この流れはPCAとほぼ同じである。(Week2の講義中にPCAを使えと言っていたが、これはSTCのことなのか?)

 

Pythonでの実装

Matlabの実装はJonathan Pillow先生の研究室のHPにある。 

→ spike-triggered covariance (STC) analysis - matlab code

 

それをPythonで書き直したらしいリポジトリがこちらである。

 

Neural Networkへの応用

最近、逆相関法やSpike triggered analysisを用いてArtificial Neural Networkの内部のニューロンの活動を可視化しようという話をよく聞く。

 

参考文献

・Schwartz, O, Pillow, JW, Rust, NC, and Simoncelli, EP. (2006). Spike-triggered neural characterization. Journal of Vision, 6(4):484-507 (pdf

 

・Simoncelli, EP, Paninski, L, Pillow, J, and Schwartz, O (2004). Characterization of neural responses with stochastic stimuli. In M Gazzaniga (ed.) The Cognitive Neurosciences, 3rd edition. MIT Press. (pdf

 

Spike-triggered average - Wikipedia

Spike-triggered covariance - Wikipedia