雑記: RNN の勾配が消失したり爆発したりする話のメモ

シチュエーション(まずL+1層パーセプトロンの場合)
  •  0 層目を入力層、L 層目を出力層として、0, 1, \cdots, L 層が連なっているとする。
  •  w_{jk}^l を、 l-1 層目の  k 番目のニューロンから  l 層目の  j 番目のニューロンへの接続の重みとし、 b_j^ll 層目の j 番目のニューロンのバイアスとする(  l =1, \cdots. L )。
  •  z_{j}^ll 番目の j 番目のニューロンの活性化前の値、 a_{j}^ll 番目の j 番目のニューロンの活性化後の値とする。つまり、活性化関数を  \sigma(\cdot) とすると、以下の関係が成り立つ。
     \displaystyle a_j^l = \sigma( z_j^l ) = \sigma \left( \sum_k w_{jk}^l a_k^{l-1} + b_j^l \right)
    •  z_{j}^l a_{j}^lニューラルネットワークへの入力ベクトル x に依存する  z_{j}^l(x) a_{j}^l(x) だが、いまはある入力ベクトルに固定されていると考えて (x) を省略する。
    • それぞれの文字の下の添え字を取ってベクトル(重みについては行列)とみなすと  a^l = \sigma( w^l a^{l-1} + b^l ) ともかける。
  • コスト関数を  C とする(例えば正解  yニューラルネットワークの出力  a^L の2乗誤差  \displaystyle \frac{1}{2}|| y - a^L ||^2 など)。
やりたいこと
  • 任意の  l, j, k について  \displaystyle \frac{\partial C}{\partial w_{jk}^l} 及び  \displaystyle \frac{\partial C}{\partial b_{j}^l} を求めたい(これらを求める1つの手続きを知りたい)。
やり方

まず、C の最終出力 a_j^L に関する偏微分  \displaystyle \frac{\partial C}{\partial a_j^L} はただちに求まる。
次に、C の最終出力の活性化前の値 a_j^L に関する偏微分を考えると、上の  \displaystyle \frac{\partial C}{\partial a_j^L} を利用して以下のようになる。

 \displaystyle \frac{\partial C}{\partial z_j^L} = \frac{\partial C}{\partial a_j^L} \frac{\partial a_j^L}{\partial z_j^L} = \frac{\partial C}{\partial a_j^L} \sigma' (z_j^L)

さらに、最終層の1つ手間の層の活性化前の値 a_j^{L-1} に関する偏微分を考える。最終層の全ニューロンの活性化前の値に関する偏微分をつかって以下のようにかける。
 \displaystyle \frac{\partial C}{\partial z_j^{L-1}} = \sum_k \frac{\partial C}{\partial z_k^L} \frac{\partial z_k^L}{\partial z_j^{L-1}} = \sum_k \frac{\partial C}{\partial z_k^L} \frac{\partial \displaystyle \sum_m \Bigl( w_{km}^{L} \sigma(z_m^{L-1}) + b_k^{L} \Bigr) }{\partial z_j^{L-1}} = \sum_k \frac{\partial C}{\partial z_k^L} w_{kj}^{L} \sigma ' (z_j^{L-1})

この要領で遡っていけば、1層目の  z_j^1 に関する偏微分までかける。つまり、ネットワーク中の全てのニューロン  z_j^l に関する C偏微分が明示的に求まる。
そうなると任意の  l, j, k について  \displaystyle \frac{\partial C}{\partial w_{jk}^l} 及び  \displaystyle \frac{\partial C}{\partial b_{j}^l} を求めるのは容易で、つまり、以下のようにかける。
 \displaystyle \frac{\partial C}{\partial w_{jk}^l} = \frac{\partial C}{\partial z_j^l} \frac{\partial z_j^l}{\partial w_{jk}^l} = \frac{\partial C}{\partial z_j^l} \frac{\sum_k w_{jk}^l a_k^{l-1} + b_j^l}{\partial w_{jk}^l} = \frac{\partial C}{\partial z_j^l} a_k^{l-1}

 \displaystyle \frac{\partial C}{\partial b_j^l} = \frac{\partial C}{\partial z_j^l} \frac{\partial z_j^l}{\partial b_j^l} = \frac{\partial C}{\partial z_j^l} \frac{\sum_k w_{jk}^l a_k^{l-1} + b_j^l}{\partial b_j^l} = \frac{\partial C}{\partial z_j^l}

RNN の場合

入力層と出力層しかないニューラルネットワークを考える。出力層が1層しかないので a_uz_u の右上の添え字を省く。この出力層をまた入力層につなげてループさせるとする。すると、何回ループしたときの出力なのかの区別が必要なので、a_u(t)z_u(t) とかく。\displaystyle \delta_u(t) \equiv \frac{\partial C}{\partial z_u(t)} とおく。
1ループ前への伝播は以下のようになる。

\displaystyle \frac{\partial \delta_v(t-1)}{\partial \delta_u(t)} = f' \bigl( z_v(t-1) \bigr) w_{uv}
q ループ前への伝播は以下のようになる。
\begin{eqnarray*}\frac{\partial \delta_v(t-q)}{\partial \delta_u(t)} &=& \sum_k \frac{\partial \delta_v(t-q)}{\partial \delta_k(t-q+1)} \frac{\partial \delta_k(t-q+1)}{\partial \delta_u(t)} \\ &=& \sum_k f' \bigl( z_v(t-q) \bigr) w_{kv} \frac{\partial \delta_k(t-q+1)}{\partial \delta_u(t)} \\ &=& f' \bigl( z_v(t-q) \bigr) \sum_k w_{kv} \frac{\partial \delta_k(t-q+1)}{\partial \delta_u(t)}\end{eqnarray*}
途中経過がわかりやすいようにユニットの添え字を取り直す。時刻 t のユニット l_0 から時刻 t-q のユニット l_q への伝播は以下のようになる。
\begin{eqnarray*} \frac{\partial \delta_{l_{q}}(t-q)}{\partial \delta_{l_0}(t)} &=& f' \bigl( z_{l_{q}}(t-q) \bigr) \sum_{l_{q-1}} w_{{l_{q-1}}{l_{q}}} \frac{\partial \delta_{l_{q-1}}(t-q+1)}{\partial \delta_{l_0}(t)} \\ &=& f' \bigl( z_{l_{q}}(t-q) \bigr) \sum_{l_{q-1}} w_{{l_{q-1}}{l_{q}}} \left[ f' \bigl( z_{l_{q-1}}(t-q+1) \bigr) \sum_{l_{q-2}} w_{{l_{q-2}}l_{q-1}} \frac{\partial \delta_{l_{q-2}}(t-q+2)}{\partial \delta_{l_0}(t)} \right] \\ &=& \sum_{l_{q-1}} \sum_{l_{q-2}} \left[ \prod_{m=q-1}^{q} \Bigl( f' \bigl( z_{l_{m}}(t-m) \bigr)  w_{{l_{m-1}}{l_{m}}} \Bigr) \frac{\partial \delta_{l_{q-2}}(t-q+2)}{\partial \delta_{l_0}(t)} \right] \\ &=& \sum_{l_{q-1}} \cdots \sum_{l_{1}} \left[ \prod_{m=2}^{q} \Bigl( f' \bigl( z_{l_{m}}(t-m) \bigr)  w_{{l_{m-1}}{l_{m}}} \Bigr) \frac{\partial \delta_{l_{1}}(t-1)}{\partial \delta_{l_0}(t)} \right] \\ &=& \sum_{l_{q-1}} \cdots \sum_{l_{1}} \prod_{m=1}^{q} \Bigl( f' \bigl( z_{l_{m}}(t-m) \bigr)  w_{{l_{m-1}}{l_{m}}} \Bigr) \end{eqnarray*}
よって、時刻 t でのユニット l_0 の誤差を、時刻 t-q のユニット l_q の誤差まで伝播させるのに、m=1, 2, \cdots, q に対して f' \bigl( z_{l_{m}}(t-m) \bigr)  w_{{l_{m-1}}{l_{m}}} が掛け合わさっていく。そのため、さかのぼるステップ数 q が大きいとき、絶対値の組み合わせによっては、勾配は発散するか、消滅する(これは RNN に限った話ではないが、通常の MLP では多層化のレベルはたかが知れているのに対して、RNN で長期の依存性を学習したい場合は特に問題になってくる; また、同じ重み w_{uv} を何回も通る伝播パスが存在する)。