論文読みメモ: 深層自己符号化器+混合ガウスモデルによる教師なし異常検知(その2)

以下の論文を読みます。

Bo Zong, Qi Song, Martin Renqiang Min, Wei Cheng, Cristian Lumezanu, Daeki Cho, Haifeng Chen. Deep Autoencoding Gaussian Mixture Model for Unsupervised Anomaly Detection. International Conference on Learning Representations, 2018. https://openreview.net/forum?id=BJJLHbb0-
※ キャラクターに元ネタはないです。お気付きの点がありましたらお手数ですがコメント等にてご指摘ください。
前回:その1 / 次回:まだ
f:id:cookie-box:20180513082851p:plain:w60

異常検知するのに有効なデータの密度推定をするのに、次元削減する「圧縮ネットワーク」と、次元削減後の点が混合正規分布のうち何番目の分布に帰属するか予測する「推定ネットワーク」を同時に学習しようという話でした。時間の都合で2節の Related Work は後回しにしますね。3節も3.3節まではイントロの時点で透けてた内容なんで3.4節までとばしますが、ここで示される DAGMM の目的関数は以下です。 \theta_e, \theta_d, \theta_m はエンコーダ、デコーダ、推定ネットワークのパラメータです。
   \displaystyle J(\theta_e, \theta_d, \theta_m) = \frac{1}{N} \sum_{i=1}^N L(x_i, x_i') + \frac{\lambda_1}{N} \sum_{i=1}^N E(z_i) + \lambda_2 P(\hat{\Sigma})
このうち  L(x_i, x_i') はサンプル  x_i とその再構築  x_i' の誤差関数で、通常L2ノルム  L(x_i, x_i') = || x_i - x_i' ||_2^2 にするのがよいとか書いてありますね。 E(z_i) はサンプル  x_i に対応する確率密度の対数にマイナスをかけたものですね。 z_i \equiv [{z_c}_i , \, f(x_i, x_i')]x_i の低次元表現です。x_i を自己符号化器でエンコードした  {z_c}_i に加えて、再構築エラー  f(x_i, x_i') が concatinate されているのに注意ですね。 f(x_i, x_i') はこれも  x_i x_i' の誤差=距離の関数ですが、必ずしも距離の定義を1つに絞る必要はなく、色々な尺度での誤差を並べて多次元にしてもよいだろうということみたいです。この論文では  E(z_i) のことをエネルギーといっていますね。 E(z_i) はネットワークの学習時につかうのみならず、学習し終えていざ異常検知するという段で「閾値よりも高エネルギーなら異常」というようにつかうと。

f:id:cookie-box:20180513082908p:plain:w60

前回、「圧縮ネットワーク」の目的関数は再構築エラーで「推定ネットワーク」の目的関数は尤度のはずだけどどのように2つのネットワークを同時に学習するんだろう、って気になってたけど、双方を重み付きで足し合わせたものを全体の目的関数とするんだね。ただこの場合、重みのバランスってどうなるんだろう…。あと、最後の項の  P(\hat{\Sigma}) は?

f:id:cookie-box:20180513082851p:plain:w60

正則化項です。次元削減と密度推定の結果、得られた密度が特徴空間内で各サンプルに対応する点の周りに局在してしまったら困りますよね。そもそも次元削減と密度推定によってサンプルとサンプルの隙間の密度を埋めたかったのに。このような悲惨な事態を避けるために以下の正則化項を導入します。
   \displaystyle P(\hat{\Sigma}) = \sum_k \sum_j \frac{1}{(\hat{\Sigma}_{k})_{j,j}}
 (\hat{\Sigma}_{k})_{j,j} は混合されている分布のうち  k 番目の分布の分散共分散行列の  j 番目の対角要素です。どの分布のどの対角要素もゼロになってほしくないので、このようなペナルティを導入するというわけです。この正則化項を入れることによって、自己符号化器部分の性能が、事前学習した自己符号化器並みによくなるとか(イントロ部分でメモに記していませんでしたが、自己符号化器は事前学習した方が自己符号化器部分の性能はよくなるが、今回の目的でそれをしてしまうと密度推定側からの要請で自己符号化器側を modify するということがしづらくなってしまうので事前学習はしないというスタンスです、すみません)。

f:id:cookie-box:20180513082908p:plain:w60

確かに、避けるべきトラップがそれだけなのかはぱっとわからないけど、その事態を避けなければいけないのはマストだね。でもその正則化項もどんな重みで足せばいいのかわからないな…。 \lambda_1, \, \lambda_2 はどう決めるの?

f:id:cookie-box:20180513082851p:plain:w60

 \lambda_1=0.1, \, \lambda_2=0.005 がオススメらしいです。おそらく試行錯誤の結果なんじゃないでしょうか。

f:id:cookie-box:20180513082908p:plain:w60

そっか。

f:id:cookie-box:20180513082851p:plain:w60

3.5節はDAGMM の密度推定プロセスを、通常の variational inference と対比するとこうって話でしょうか。ただモデルの選定基準や性能に直接関わる話ではなさそうなのでとばしますね。3.6節の内容は上で少し触れました。なので次は4節の検証ですね。

f:id:cookie-box:20180513082908p:plain:w60

最適化アルゴリズムの話は?

f:id:cookie-box:20180513082851p:plain:w60

少なくとも3節にはないですね。目的関数はわかったんで、プログラミング部の人に「これを最小化してください」って実装してもらえばいいんじゃないですか?

f:id:cookie-box:20180513082908p:plain:w60

迷惑だよね!?

f:id:cookie-box:20180513082851p:plain:w60

今回検証に用いられているベンチマークデータセットですが、論文にはそれぞれ何のデータか明記されていないので、軽く調べておきましょう(リンク先は論文に記述されているリポジトリとは限りません)。

KDDCUP
(KDD Cup 99 Data Set)
これは「マルウェア攻撃」と「正常通信」の両方を含む通信データで、494,021件のデータが含まれています(これは10%抽出版の件数であって、元々はサイズが約10倍のデータセットです)。「正常通信」の方が20%と少ないので、DAGMM の検証では「正常通信」の側を「検知すべき異常」とみなしています。調べるとこのデータセットは「マルウェア攻撃」にも22種類の攻撃が含まれているらしいので、だから混合分布によるモデリングが適するのかもしれません。
http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html
ThyroidThyroidとは「甲状腺」ですね。3772件のデータには「甲状腺機能低下」「やや低下」「正常」の3クラスのデータが含まれ、検証ではこのうち「甲状腺機能低下」を「検知すべき異常」としているようです。
http://odds.cs.stonybrook.edu/thyroid-disease-dataset/
ArrhythmiaArrhythmiaは「不整脈」です。452件のデータにはそれぞれ1~16のラベルが付されており、1が「正常」、2~15がそれぞれタイプの異なる不整脈、16がその他という意味とのことです。DAGMM の検証では、データの絶対数が少ないラベル 3~5, 7~9, 14, 15 を異常とみなしたようです。
https://archive.ics.uci.edu/ml/datasets/arrhythmia
KDDCUP-Revこの論文の著者が KDDCUP データセットから改めて抽出したデータセットですね。元々の KDDCUP データセットは全体の20%が「正常通信」で残りの80%が「マルウェア攻撃」ですが、DAGMM による異常検知でも「正常通信」を正常として取り扱うため、「正常通信」のデータは全て保持しておいて、それに加えて「正常通信」のデータサイズの20%の件数にあたる「マルウェア攻撃」をランダムに抽出して、全体の80%が「正常通信」であるような新しいデータセットを用意したということのようです。

(次回があれば)つづく