雑記: SQL で適当に AUC を計算すると遅い

追記: この記事には 2 通りの方法をかいているが前者の方法は遅いので後者の方法(つまり参考文献 [1] の方法)でやるのがよい。

参考文献

  1. SQLでAUCを算出する方法 |Dentsu Digital Tech Blog|note(2022年6月22日参照).
  2. Prestoでは集計関数をWINDOW関数として扱える | 分析ノート(2022年6月22日参照).



Presto で AUC を計算したいとします。ちょうど過去の記事(雑記: AUC の話とその scikit-learn での計算手順の話 - クッキーの日記)に具体例に対する計算方法があったのでまずこれと同じデータを用意します(以下)。

WITH data AS (
    SELECT 
        CAST(label AS INT) AS label,
        CAST(score AS INT) AS score
    FROM
    (
        SELECT
            '0,0,0,1,1,0,1,1' AS label_,
            '2,1,2,4,2,1,3,5' AS score_
    ) x
    CROSS JOIN
        UNNEST(
            split(x.label_, ','),
            split(x.score_, ',')
        ) AS t(label, score)
),

上の WITH 句に続けて AUC の計算を実装すると以下のようになります。実行すると 0.9375 と出てきて大丈夫です → 追記: しかしこの実装は巨大データには遅いです。

-- ROCカーブの座標
tpr_fpr AS (
    -- ROCカーブが原点から開始することを保証するために以下の一行を追加する
    SELECT
        0 AS row_num,
        0.0 AS tpr,
        0.0 AS fpr

    UNION ALL

    -- ROCカーブの座標
    SELECT
        ROW_NUMBER() OVER(ORDER BY threshold DESC) AS row_num,
        CAST(COUNT(IF(label = 1 AND predict = 1, 1, NULL)) AS DOUBLE) / COUNT(IF(label = 1, 1, NULL)) AS tpr,
        CAST(COUNT(IF(label = 0 AND predict = 1, 1, NULL)) AS DOUBLE) / COUNT(IF(label = 0, 1, NULL)) AS fpr
    FROM
    (
        -- 各閾値候補に全データを LEFT JOIN してその閾値での判定を付加する
        SELECT
            threshold,
            label,
            score,
            CASE
                WHEN score >= threshold THEN 1
                ELSE 0
            END AS predict
        FROM
        (
            SELECT DISTINCT
                score AS threshold
            FROM
                data
        )
        LEFT JOIN
            data
        ON
            TRUE
    )
    GROUP BY
        threshold
)

-- 台形則でカーブ下の面積を算出
SELECT
    SUM(area) AS auc
FROM
(
    SELECT
        -- 0.5 * (上底 + 下底) * 高さ
        0.5 * (a.tpr + b.tpr) * (b.fpr - a.fpr) AS area
    FROM
        tpr_fpr a
    JOIN
        tpr_fpr b
    ON
        a.row_num + 1 = b.row_num
)
(0.9375,)

しかし、インターネット上でみつけた参考文献 [1] をみると上の実装よりもずっとすっきりしていることがわかります。この要因は、参考文献 [2] にあるように集計関数を WINDOW 関数として扱っていなかったことと、LAG 関数を利用していなかったことにあるので、この 2 点を直すと以下になります。ちなみに参考文献 [1] はラベルが 1 か 0 であることを前提にもっとすっきりさせているので参考文献 [1] をみてください。

tpr_fpr AS (
    -- ROCカーブの座標
    SELECT
        score,
        CAST(COUNT(IF(label = 1, 1, NULL)) OVER(ORDER BY score DESC) AS DOUBLE) / COUNT(IF(label = 1, 1, NULL)) OVER() AS tpr,
        CAST(COUNT(IF(label = 0, 1, NULL)) OVER(ORDER BY score DESC) AS DOUBLE) / COUNT(IF(label = 0, 1, NULL)) OVER() AS fpr
    FROM
        data
)

-- 台形則でカーブ下の面積を算出
SELECT
    SUM(area) AS auc
FROM
(
    SELECT
        -- 0.5 * (上底 + 下底) * 高さ
        0.5 * (tpr + LAG(tpr, 1, 0.0) OVER (ORDER BY score DESC)) * (fpr - LAG(fpr, 1, 0.0) OVER (ORDER BY score DESC)) AS area
    FROM
        tpr_fpr
)
(0.9375,)

なお、ROC カーブはスコアが取りうる値が限られている(例えばいくつかの離散値である)ときはデータ数に比べて実質的なカーブの座標の点数は小さくなりますが、後者の実装では必ずデータ数だけの座標(実際には曲線の形を変えない)を出すことにはなります。それを出したくないなら CASE 文を使って必要なときだけ集計 WINDOW 関数すればいいですがそれで処理時間に恩恵があるのかわかりません。では前者の実装はというとデータ数が大きくスコアが連続的なときは JOIN ON TRUE の結果が巨大になりそうですがだからだめかはやってみていないのでわかりません → 追記:やってみたら前者が遅かったです。なんやかんやいう以前に内部で JOIN ON TRUE するオーバーヘッドがすごいんだと思います(適当)。