系列データを扱うネットワークの誤差逆伝播の話(途中)

お気付きの点がありましたらご指摘いただけますと幸いです。

  1. Paul Vicol, Luke Metz, Jascha Sohl-Dickstein. Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies. In Proceedings of the 38th International Conference on in Machine Learning (ICML 2021), 2021.

    Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies

  2. A Gentle Introduction to Backpropagation Through Time
f:id:cookie-box:20211104100722p:plain:w60

参考文献 1. は RNN の学習に進化戦略(ES)を工夫して適用することによってバイアス・バリアンスの抑制、学習の高速化、省メモリ化ができるというものでした。その具体的な手法(パーシステント進化戦略)は、参考文献 1. の 6 ページ目から読み取るに、以下のようなものではないかと思います。参考文献に倣って先に単純な適用を示し、次にそれをどう変更するかを文字色を変えて示します。

進化戦略を単純に適用した RNN の学習(ES)

入力データを K ステップずつの区間に区切り、区間ごとに期待する出力値を用意しておく。このデータに最適化するように RNN のパラメータを現在の値 \theta から更新することを考える( \thetad 次元)。
探索に用いる粒子の数を N とする。
RNN の隠れ状態ベクトルに初期値をセットする。s \leftarrow s_0
最初の区間に対して以下のように最適化を実施する。
  • 求める d 次元の勾配  \hat{g}^{\rm ES} にゼロベクトルをセットする。\hat{g}^{\rm ES} \leftarrow 0
  •  i = 1, \cdots, N に対して以下の処理を実行する。
    •  i 回目の処理の開始時、 i が奇数であれば  N(0, \sigma^2 I_d) から探索方向  \epsilon^{(i)} を取り出す。
       i が偶数であれば  \epsilon^{(i)} = - \epsilon^{(i-1)} とする。
    • パラメータを  \epsilon^{(i)} ずらして損失を求める(隠れ状態も s' に更新されるが利用しない)。
       s', \hat{L}_K^{(i)} \leftarrow {\rm unroll}(s, \theta + \epsilon^{(i)}, K)
    •  \epsilon^{(i)} ずらした箇所の損失が  \hat{L}_K^{(i)} であることに基づき \hat{g}^{\rm ES} を更新する。
      \hat{g}^{\rm ES} \leftarrow \hat{g}^{\rm ES} + \epsilon^{(i)} \hat{L}_K^{(i)}
  • ループし終わったら \hat{g}^{\rm ES} をループ回数と探索方向の分散で割る。
    \hat{g}^{\rm ES} \leftarrow \hat{g}^{\rm ES} / (N \sigma^2)
  • 先に次の区間に対する最適化のために現在のパラメータで隠れ状態 s を更新しておく。
     s, \hat{L}_K \leftarrow {\rm unroll}(s, \theta, K)
  • \hat{g}^{\rm ES} と逆方向にパラメータを更新する(\alpha は学習率)。
    \theta \leftarrow \theta - \alpha \hat{g}^{\rm ES}
2番目以降の区間に対しても更新された s を用いて同様に最適化を重ねていく。

パーシステント進化戦略を適用した RNN の学習(PES

入力データを K ステップずつの区間に区切り、区間ごとに期待する出力値を用意しておく。このデータに最適化するように RNN のパラメータを現在の値 \theta から更新することを考える( \thetad 次元)。
探索に用いる粒子の数を N とする。
粒子の数だけ RNN の隠れ状態ベクトルを用意して初期値をセットする。s^{(i)} \leftarrow s_0
粒子の数だけ探索方向を蓄積するためのベクトルを用意してゼロベクトルをセットする。\xi^{(i)} \leftarrow 0
最初の区間に対して以下のように最適化を実施する。
  • 求める d 次元の勾配  \hat{g}^{\rm PES} にゼロベクトルをセットする。\hat{g}^{\rm PES} \leftarrow 0
  •  i = 1, \cdots, N に対して以下の処理を実行する。
    •  i 回目の処理の開始時、 i が奇数であれば  N(0, \sigma^2 I_d) から探索方向  \epsilon^{(i)} を取り出す。
       i が偶数であれば  \epsilon^{(i)} = - \epsilon^{(i-1)} とする。
    • パラメータを  \epsilon^{(i)} ずらして損失を求める。この粒子用の隠れ状態  s^{(i)} も更新する。
       s^{(i)}, \hat{L}_K^{(i)} \leftarrow {\rm unroll}(s^{(i)}, \theta + \epsilon^{(i)}, K)
    • この粒子の探索方向を蓄積する。 \xi^{(i)} \leftarrow \xi^{(i)} + \epsilon^{(i)}
    • ここまで  \xi^{(i)} ずらしてきた箇所の損失が  \hat{L}_K^{(i)} であることに基づき \hat{g}^{\rm PES} を更新する。
      \hat{g}^{\rm PES} \leftarrow \hat{g}^{\rm PES} + \xi^{(i)} \hat{L}_K^{(i)}
  • ループし終わったら \hat{g}^{\rm PES} をループ回数と探索方向の分散で割る。
    \hat{g}^{\rm PES} \leftarrow \hat{g}^{\rm PES} / (N \sigma^2)
  • 先に次の区間に対する最適化のために現在のパラメータで隠れ状態 s を更新しておく。
  • \hat{g}^{\rm PES} と逆方向にパラメータを更新する(\alpha は学習率)。
    \theta \leftarrow \theta - \alpha \hat{g}^{\rm PES}
2番目以降の区間に対しては、更新された各粒子の隠れ状態 s^{(i)} と探索方向の蓄積 \xi^{(i)} を用いて同様に最適化を重ねていく。
…まず、どちらの方法でも、直接いくつかの点を探索してしまい、損失が最も小さくなる方向を探していると思います。RNN では勾配消失/爆発が生じやすく誤差逆伝播が不安定なためこうしていると思います。ただ、単純な ES では、一度最初の区間に対して最適化を実施すると、元々どんなパラメータであってそれをどう更新したか、という情報は失われてしまいます。

つづいたらつづく