トランスフォーマーと素敵なカーネル法【第1話】

お気付きの点がありましたらご指摘いただけますと幸いです。

  1. Yifan Chen, Qi Zeng, Heng Ji, Yun Yang. Skyformer: Remodel Self-Attention with Gaussian Kernel and Nyström Method. In NeurIPS 2021. [Proceedings]
    • カーネル法の計算量削減手法(Nyström 近似)をセルフアテンションに応用した研究であり、正定値カーネルでもなければ対称でもないソフトマックスによるセルフアテンションを、ガウシアンカーネルに置き換えて対称行列に拡張してカーネル法の土俵に持ち込んだうえで Nyström 近似を適用している。なので、「『$i$ 単語目から $j$ 単語目にどれだけ注意すべきか』をすべて突き止めよう」という枠組みは崩さずに、その $n \times n$ 行列を低コストに得ようとする低ランク近似路線になっている。
  2. Shengjie Luo, Shanda Li, Tianle Cai, Di He, Dinglan Peng, Shuxin Zheng, Guolin Ke, Liwei Wang, Tie-Yan Liu. Stable, Fast and Accurate: Kernelized Attention with Relative Positional Encoding. In NeurIPS 2021. [Proceedings]
    • こちらもカーネル法の計算量削減手法(Random Feature)をセルフアテンションに応用した研究で、ソフトマックス自体を特徴関数の内積で表現しようとする研究の流れを汲み、相対位置エンコーディングしたい場合に拡張している。ので [1] とは異なり、「$i$ 単語目から $j$ 単語目に注意するのは計算量がかかるので先に文章全体の特徴をまとめておこう」という要約路線(勝手に命名)(の延長)である。詳しくはこちらを参照。
  • まだ関連記事はないです。

まとめ

  • 通常の Transformer では「$i$ 単語目から $j$ 単語目にどれだけ注意すべきか」をすべて計算するために、$n$ 単語の文章を処理するのに $\mathcal{O}(n^2)$ の計算量がかかる。
  • 計算量を $\mathcal{O}(n^2)$ より小さくしたいとき、以下のような路線が考えられている、と思われる。
    • スパース路線: 「$i$ 単語目から $j$ 単語目にどれだけ注意すべきか」をすべて計算するのではなく一部だけ計算する。アドホックになりがちと思われるが実践的なように思われる。
    • 低ランク近似路線: 「$i$ 単語目から $j$ 単語目にどれだけ注意すべきか」をすべて計算することはあきらめずに、この $n \times n$ 行列を低コストに得られるものに近似する。
    • 要約路線: 「$i$ 単語目から文章全体(あるいはいくつかの単語ブロック)にどれだけ注意すべきか」というようにあらかじめ注意の向き先を要約しておく。要は注意と和を交換する。
  • 他方、カーネル法は「$i$ 番目のデータと $j$ 番目のデータがどれだけ関連しているか(=グラム行列)」を中心に推測を行う。が、こちらも $n$ が大きいとき推測プロセスでの計算量が大きくなるために様々な工夫が打ち出されている。
    • 「$i$ 番目のデータと $j$ 番目のデータがどれだけ関連しているか」をもつことはあきらめずに、この $n \times n$ 行列を低ランク近似する。
    • 正定値カーネルを特徴関数に展開する。要は和と交換できるようにする。
f:id:cookie-box:20211229151958p:plain:w60

Transformer の訓練に系列長の 2 乗の計算量がかかってしまうというのは、$i$ 単語目にセルフアテンションを適用したものが以下のようになるからですよね。$i$ 単語目から $j$ 単語目への重み $\tilde{a}_{i,j}$ の分子にあたる $\exp(q_i \cdot k_j / \sqrt{p})$ は $n$ 回計算する必要があり、これを全 $n$ 単語に繰り返す必要があります。

$$
y_i = \sum_{j=1}^n \tilde{a}_{i,j} v_j = \sum_{j=1}^n \left[ \frac{\exp(\frac{1}{\sqrt{p}} q_i \cdot k_{j})}{\sum_{j'=1}^n \exp(\frac{1}{\sqrt{p}} q_i \cdot k_{j'})} \right] v_j
$$

f:id:cookie-box:20211229162343p:plain:w60

もし $\exp(q_i \cdot k_j / \sqrt{p})$ の部分が単に内積 $q_i \cdot k_j$ なら、内積の線形性 $a q_i \cdot k_j + b q_i \cdot k_{j'} = q_i \cdot (a k_j + b k_{j'})$ から、内積を取る前に内積と $\sum_{j=1}^n$ を交換できて、その $n$ 回の計算を1回にまとめられるんだけどね。

f:id:cookie-box:20211229151958p:plain:w60

$\exp(q_i \cdot k_j / \sqrt{p}) + \exp(q_i \cdot k_{j'} / \sqrt{p})$ を同じ要領でまとめることはできませんからね。セルフアテンションは内積ではありませんから……しかし、その内積ではないセルフアテンションにカーネル法の計算量削減手法を適用しようといった研究が色々とみられるのですね? NeurIPS 2021 においても複数みられます [1] [2]。例えば [1] では計算量削減のため、カーネル法における Nyström 近似を半正定値でない行列に適用する方法を編み出したと……Nyström 近似?

f:id:cookie-box:20211229162343p:plain:w60

グラム行列の固有値固有ベクトルを近似的に低コストに得る手法だね。

発想として、グラム行列の固有値固有ベクトルって、元の正定値カーネル固有値と固有関数(の各訓練データ上での値)の近似になっているんだよね。ただ正規化定数の分だけ調整は要るけどね。固有関数の直交性から、固有関数の各訓練データ上の値の 2 乗の $1/n$ 倍の和が $1$ にならなければならないから、グラム行列の単位固有ベクトルを $\sqrt{n}$ 倍したものが固有関数の各訓練データ上の値に対応する。この要領で、一部のサンプルに対応するグラム行列の固有値固有ベクトルから全データに対するグラム行列の固有値固有ベクトルを近似するのがカーネル法における Nyström 近似だね。

f:id:cookie-box:20211229151958p:plain:w60

単位固有ベクトルを $\sqrt{n}$ 倍したものが固有関数の値の列になるのですか……? それだと、もし仮にグラム行列の固有ベクトルの1本目が $(1, 0, 0, \cdots )$ であったとしたら、1つ目の固有関数の1つ目のデータ $X_1$ における値が無限に近づいていきませんか??

f:id:cookie-box:20211229162343p:plain:w60

訓練データをいくら増やしていっても $(1, 0, 0, \cdots )$ であるなら1つ目の固有関数はもはや1つ目のデータの点にそびえたつデルタ関数じゃないか……そりゃ無限に近づくよ……。

f:id:cookie-box:20211229151958p:plain:w60

なるほど……いやしかし、セルフアテンション行列(※ $\tilde{a}_{i,j}$ を成分とした行列のことをこうよぶことにする)は別にグラム行列ではないですよね? 分子の $\exp(q_i \cdot k_j / \sqrt{p})$ だけみれば指数型カーネルではありますが、だとしてもセルフアテンション行列は対称行列ではありません。$i$ 単語目のクエリ $q_i$ とキー $k_i$ は一般に異なりますから。$i$ 単語目の $j$ 単語目への注意度と、$j$ 単語目の $i$ 単語目への注意度は一般に一致しません。Nyström 近似はみたところ対称行列であることにもその成分が正定値カーネルの値であることにも立脚しているようにみえますが、この手法をどうセルフアテンションに適用しようというのでしょうか……?

f:id:cookie-box:20211229151958p:plain:w60

3 節によると、セルフアテンション行列は以下のようにかける? $\tilde{a}_{i,j}$ の分母を対角行列にもつ行列 $D$ の逆行列と、$\tilde{a}_{i,j}$ の分子のみ並べた行列 $A$(非正規化セルフアテンション行列とよびましょう)の積の形にしたのですね。

$$
\begin{align}
\tilde{A} &=
\begin{pmatrix}
\sum_{j=1} \exp(q_1 \cdot k_{j} / \sqrt{p}) & 0 & \cdots \\
0 & \sum_{j=1} \exp(q_2 \cdot k_{j} / \sqrt{p}) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}^{-1}
\begin{pmatrix}
\exp(q_1 \cdot k_{1} / \sqrt{p}) & \exp(q_1 \cdot k_{2} / \sqrt{p}) & \cdots \\
\exp(q_2 \cdot k_{1} / \sqrt{p}) & \exp(q_2 \cdot k_{2} / \sqrt{p}) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\\
& \equiv D^{-1} A
\end{align}
$$

それで、非正規化セルフアテンション行列 $A$ は以下のように変形できますね。指数型カーネルからガウシアンカーネルをひねり出したようなイメージです。

$$
\begin{align}
A &= \begin{pmatrix}
\exp(q_1 \cdot k_{1} / \sqrt{p}) & \exp(q_1 \cdot k_{2} / \sqrt{p}) & \cdots \\
\exp(q_2 \cdot k_{1} / \sqrt{p}) & \exp(q_2 \cdot k_{2} / \sqrt{p}) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix} \\
&= \begin{pmatrix}
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix} \\
&= \begin{pmatrix}
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_1 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) \exp( -\frac{\| q_2 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\begin{pmatrix}
\exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) &0& \cdots \\
0& \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\\
&=
\begin{pmatrix}
\exp( \frac{\| q_1 \|^2}{ 2 \sqrt{p}} ) &0& \cdots \\
0& \exp( \frac{\| q_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\begin{pmatrix}
\exp( -\frac{\| q_1 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| q_1 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( -\frac{\| q_2 - k_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| q_2 - k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\begin{pmatrix}
\exp( \frac{\| k_1 \|^2}{ 2 \sqrt{p}} ) &0& \cdots \\
0& \exp( \frac{\| k_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
\\
& \equiv D_Q^{1/2} C D_K^{1/2}
\\
\end{align}
$$
しかしガウシアンカーネルにしてみたところで $C$ は一般に半正定値にはなりません。先ほどもいったように、クエリ $q_i$ とキー $k_i$ は異なりますから。もし $k_i = q_i$ とすれば半正定値ですが。それは指数型カーネルのままであろうとそうですね。ただまあ 4 節に進むと、今回はこちらの $C$ をセルフアテンション行列として $\tilde{A}$ の代わりに使用するのですね。代用することの正当化として、「セルフアテンションの長所は限られた重要な単語にのみ注意できる点だがガウシアンカーネルも似たような挙動をする」とか「$\tilde{A}=D^{-1} A$ と $C=D_Q^{-1/2} A D_K^{-1/2}$ は似ている」とか色々ありますがその辺は実際セルフアテンションとして機能すればどうでもいいです。しかし「条件数が小さくなって訓練が安定になる」というところはそれらよりははっきりと利点のようにみえます。

それで肝心の「半正定値でない」をどう克服するかですが……$Q$ と $K$ を上下に連結してクエリとキーを一緒くたに扱うイメージですかね。その行列の各行のベクトル(論文中に文字が定義されていないので勝手に $z_i$ と置きます)は、$(z_1, \cdots, z_n, z_{n+1}, \cdots, z_{2n}) = (q_1, \cdots, q_n, k_{1}, \cdots, k_{n})$ になるはずです。$(z_1, \cdots, z_{2n}$ に対するグラム行列(カーネルはガウシアンカーネル)は以下になりますね。$\overline{B}$ の右上 $n\times n$ ブロック(あるいは左下 $n\times n$ ブロックの転置)が $C$ と一致しています。

$$
\overline{B} \equiv
\begin{pmatrix}
\exp( -\frac{\| z_1 - z_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| z_1 - z_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\exp( -\frac{\| z_2 - z_1 \|^2}{ 2 \sqrt{p}} ) &
\exp( -\frac{\| z_2 - z_2 \|^2}{ 2 \sqrt{p}} ) & \cdots \\
\vdots & \vdots & \ddots
\end{pmatrix}
$$

つづいたらつづく