追記: この記事には 2 通りの方法をかいているが前者の方法は遅いので後者の方法(つまり参考文献 [1] の方法)でやるのがよい。
Presto で AUC を計算したいとします。ちょうど過去の記事(雑記: AUC の話とその scikit-learn での計算手順の話 - クッキーの日記)に具体例に対する計算方法があったのでまずこれと同じデータを用意します(以下)。
参考文献
- SQLでAUCを算出する方法 |Dentsu Digital Tech Blog|note(2022年6月22日参照).
- Prestoでは集計関数をWINDOW関数として扱える | 分析ノート(2022年6月22日参照).
Presto で AUC を計算したいとします。ちょうど過去の記事(雑記: AUC の話とその scikit-learn での計算手順の話 - クッキーの日記)に具体例に対する計算方法があったのでまずこれと同じデータを用意します(以下)。
WITH data AS ( SELECT CAST(label AS INT) AS label, CAST(score AS INT) AS score FROM ( SELECT '0,0,0,1,1,0,1,1' AS label_, '2,1,2,4,2,1,3,5' AS score_ ) x CROSS JOIN UNNEST( split(x.label_, ','), split(x.score_, ',') ) AS t(label, score) ),
上の WITH 句に続けて AUC の計算を実装すると以下のようになります。実行すると 0.9375 と出てきて大丈夫です → 追記: しかしこの実装は巨大データには遅いです。
-- ROCカーブの座標 tpr_fpr AS ( -- ROCカーブが原点から開始することを保証するために以下の一行を追加する SELECT 0 AS row_num, 0.0 AS tpr, 0.0 AS fpr UNION ALL -- ROCカーブの座標 SELECT ROW_NUMBER() OVER(ORDER BY threshold DESC) AS row_num, CAST(COUNT(IF(label = 1 AND predict = 1, 1, NULL)) AS DOUBLE) / COUNT(IF(label = 1, 1, NULL)) AS tpr, CAST(COUNT(IF(label = 0 AND predict = 1, 1, NULL)) AS DOUBLE) / COUNT(IF(label = 0, 1, NULL)) AS fpr FROM ( -- 各閾値候補に全データを LEFT JOIN してその閾値での判定を付加する SELECT threshold, label, score, CASE WHEN score >= threshold THEN 1 ELSE 0 END AS predict FROM ( SELECT DISTINCT score AS threshold FROM data ) LEFT JOIN data ON TRUE ) GROUP BY threshold ) -- 台形則でカーブ下の面積を算出 SELECT SUM(area) AS auc FROM ( SELECT -- 0.5 * (上底 + 下底) * 高さ 0.5 * (a.tpr + b.tpr) * (b.fpr - a.fpr) AS area FROM tpr_fpr a JOIN tpr_fpr b ON a.row_num + 1 = b.row_num )
(0.9375,)
しかし、インターネット上でみつけた参考文献 [1] をみると上の実装よりもずっとすっきりしていることがわかります。この要因は、参考文献 [2] にあるように集計関数を WINDOW 関数として扱っていなかったことと、LAG 関数を利用していなかったことにあるので、この 2 点を直すと以下になります。ちなみに参考文献 [1] はラベルが 1 か 0 であることを前提にもっとすっきりさせているので参考文献 [1] をみてください。
tpr_fpr AS ( -- ROCカーブの座標 SELECT score, CAST(COUNT(IF(label = 1, 1, NULL)) OVER(ORDER BY score DESC) AS DOUBLE) / COUNT(IF(label = 1, 1, NULL)) OVER() AS tpr, CAST(COUNT(IF(label = 0, 1, NULL)) OVER(ORDER BY score DESC) AS DOUBLE) / COUNT(IF(label = 0, 1, NULL)) OVER() AS fpr FROM data ) -- 台形則でカーブ下の面積を算出 SELECT SUM(area) AS auc FROM ( SELECT -- 0.5 * (上底 + 下底) * 高さ 0.5 * (tpr + LAG(tpr, 1, 0.0) OVER (ORDER BY score DESC)) * (fpr - LAG(fpr, 1, 0.0) OVER (ORDER BY score DESC)) AS area FROM tpr_fpr )
(0.9375,)
なお、ROC カーブはスコアが取りうる値が限られている(例えばいくつかの離散値である)ときはデータ数に比べて実質的なカーブの座標の点数は小さくなりますが、後者の実装では必ずデータ数だけの座標(実際には曲線の形を変えない)を出すことにはなります。それを出したくないなら CASE 文を使って必要なときだけ集計 WINDOW 関数すればいいですがそれで処理時間に恩恵があるのかわかりません。では前者の実装はというとデータ数が大きくスコアが連続的なときは JOIN ON TRUE の結果が巨大になりそうですがだからだめかはやってみていないのでわかりません → 追記:やってみたら前者が遅かったです。なんやかんやいう以前に内部で JOIN ON TRUE するオーバーヘッドがすごいんだと思います(適当)。