お気付きの点がありましたらご指摘いただけますと幸いです。
- Yifan Chen, Qi Zeng, Heng Ji, Yun Yang. Skyformer: Remodel Self-Attention with Gaussian Kernel and Nyström Method. In NeurIPS 2021. [Proceedings]
- カーネル法の計算量削減手法(Nyström 近似)をセルフアテンションに応用した研究であり、正定値カーネルでもなければ対称でもないソフトマックスによるセルフアテンションを、ガウシアンカーネルに置き換えて対称行列に拡張してカーネル法の土俵に持ち込んだうえで Nyström 近似を適用している。なので、「『$i$ 単語目から $j$ 単語目にどれだけ注意すべきか』をすべて突き止めよう」という枠組みは崩さずに、その $n \times n$ 行列を低コストに得ようとする低ランク近似路線になっている。
- Shengjie Luo, Shanda Li, Tianle Cai, Di He, Dinglan Peng, Shuxin Zheng, Guolin Ke, Liwei Wang, Tie-Yan Liu. Stable, Fast and Accurate: Kernelized Attention with Relative Positional Encoding. In NeurIPS 2021. [Proceedings]
- まだ関連記事はないです。
まとめ
|
Transformer の訓練に系列長の 2 乗の計算量がかかってしまうというのは、$i$ 単語目にセルフアテンションを適用したものが以下のようになるからですよね。$i$ 単語目から $j$ 単語目への重み $\tilde{a}_{i,j}$ の分子にあたる $\exp(q_i \cdot k_j / \sqrt{p})$ は $n$ 回計算する必要があり、これを全 $n$ 単語に繰り返す必要があります。
y_i = \sum_{j=1}^n \tilde{a}_{i,j} v_j = \sum_{j=1}^n \left[ \frac{\exp(\frac{1}{\sqrt{p}} q_i \cdot k_{j})}{\sum_{j'=1}^n \exp(\frac{1}{\sqrt{p}} q_i \cdot k_{j'})} \right] v_j
$$
グラム行列の固有値と固有ベクトルを近似的に低コストに得る手法だね。
- Nyström 近似 - Cookipedia(注: リンク先は筆者による適当な解釈)
訓練データをいくら増やしていっても $(1, 0, 0, \cdots )$ であるなら1つ目の固有関数はもはや1つ目のデータの点にそびえたつデルタ関数じゃないか……そりゃ無限に近づくよ……。
なるほど……いやしかし、セルフアテンション行列(※ $\tilde{a}_{i,j}$ を成分とした行列のことをこうよぶことにする)は別にグラム行列ではないですよね? 分子の $\exp(q_i \cdot k_j / \sqrt{p})$ だけみれば指数型カーネルではありますが、だとしてもセルフアテンション行列は対称行列ではありません。$i$ 単語目のクエリ $q_i$ とキー $k_i$ は一般に異なりますから。$i$ 単語目の $j$ 単語目への注意度と、$j$ 単語目の $i$ 単語目への注意度は一般に一致しません。Nyström 近似はみたところ対称行列であることにもその成分が正定値カーネルの値であることにも立脚しているようにみえますが、この手法をどうセルフアテンションに適用しようというのでしょうか……?
3 節によると、セルフアテンション行列は以下のようにかける? $\tilde{a}_{i,j}$ の分母を対角行列にもつ行列 $D$ の逆行列と、$\tilde{a}_{i,j}$ の分子のみ並べた行列 $A$(非正規化セルフアテンション行列とよびましょう)の積の形にしたのですね。
\begin{align}
\tilde{A} &=
\begin{pmatrix}
\sum_{j=1} \exp(q_1 \cdot k_{j} / \sqrt{p}) & 0 & \cdots \\
0 & \sum_{j=1} \exp(q_2 \cdot k_{j} / \sqrt{p}) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}^{-1}
\begin{pmatrix}
\exp(q_1 \cdot k_{1} / \sqrt{p}) & \exp(q_1 \cdot k_{2} / \sqrt{p}) & \cdots \\
\exp(q_2 \cdot k_{1} / \sqrt{p}) & \exp(q_2 \cdot k_{2} / \sqrt{p}) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\\
& \equiv D^{-1} A
\end{align}
$$
それで、非正規化セルフアテンション行列 $A$ は以下のように変形できますね。指数型カーネルからガウシアンカーネルをひねり出したようなイメージです。
\begin{align}
A &= \begin{pmatrix}
\exp(q_1 \cdot k_{1} / \sqrt{p}) & \exp(q_1 \cdot k_{2} / \sqrt{p}) & \cdots \\
\exp(q_2 \cdot k_{1} / \sqrt{p}) & \exp(q_2 \cdot k_{2} / \sqrt{p}) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix} \\
&= \begin{pmatrix}
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix} \\
&= \begin{pmatrix}
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\begin{pmatrix}
\exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) &0& \cdots \\
0& \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\\
&=
\begin{pmatrix}
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) &0& \cdots \\
0& \exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\begin{pmatrix}
\exp( -\frac{\| q_1 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| q_1 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( -\frac{\| q_2 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| q_2 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\begin{pmatrix}
\exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) &0& \cdots \\
0& \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\\
& \equiv D_Q^{1/2} C D_K^{1/2}
\\
\end{align}
$$
それで肝心の「半正定値でない」をどう克服するかですが……$Q$ と $K$ を上下に連結してクエリとキーを一緒くたに扱うイメージですかね。その行列の各行のベクトル(論文中に文字が定義されていないので勝手に $z_i$ と置きます)は、$(z_1, \cdots, z_n, z_{n+1}, \cdots, z_{2n}) = (q_1, \cdots, q_n, k_{1}, \cdots, k_{n})$ になるはずです。$(z_1, \cdots, z_{2n}$ に対するグラム行列(カーネルはガウシアンカーネル)は以下になりますね。$\overline{B}$ の右上 $n\times n$ ブロック(あるいは左下 $n\times n$ ブロックの転置)が $C$ と一致しています。
\overline{B} \equiv
\begin{pmatrix}
\exp( -\frac{\| z_1 - z_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| z_1 - z_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( -\frac{\| z_2 - z_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| z_2 - z_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
$$
つづいたらつづく