- 例えばアヤメの 3 クラス分類なら LightGBM は 1 ラウンドごとに 3 本の木を構築し、それぞれの木は「セトーサらしさのスコア」「バーシカラーらしさのスコア」「バージニカらしさのスコア」を前ラウンドまでの累積スコアからどれだけ変動させるかを出力する。最終的に各ラベルらしい確率を出力するときは、各ラベルらしさの累積スコアが Softmax される。
- 木の訓練時には交差エントロピー損失が最も減る分岐を探索するが、このとき厳密に損失を計算するのではなく、前ラウンドまでの累積スコアの周りでの 2 次までのテイラー近似 (のヘッセ行列の非対角成分を無視してその分補正したもの) で損失の減少量 (ゲイン) を見積もる。
- 訓練時に同ラウンドの他クラスの結果を反映するなどはしない=同ラウンド内の木の根の分岐のゲイン減少量の起点は同じ (前ラウンド終了時点) である。
- Parameters Tuning — LightGBM 4.6.0.99 documentation: 「ゲイン」とはその分岐を追加することで損失を減じられる量のこととある。
- GitHub - microsoft/LightGBM: A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks.: LightGBM のソースであり、不明点があればこの実装が正である。手元の Ubuntu 環境ではデバッグプリントを入れてビルド・実行することができた。
- 6 Gradient Boosted Decision Trees – Machine Learning for Economics: 勾配ブースティング決定木の損失の 2 次近似について記述がある。
multiclass (予測確率の対数損失 = 交差エントロピーの全サンプル平均) とし、また簡単のため num_leaves=2, n_estimators=2, learning_rate=1 にしました。1 本の木には最大でも葉は 2 枚のみ (1 回分岐するのみ)、訓練は 2 ラウンドのみ、学習率 1 ということですね。▼ train.py (クリックして展開)
import lightgbm as lgb from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report import pandas as pd import numpy as np def train_model(x_train, y_train, x_test, y_test): # LGBM モデルインスタンス作成 # https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html#lightgbm-lgbmclassifier # https://lightgbm.readthedocs.io/en/latest/Parameters.html model = lgb.LGBMClassifier( objective='multiclass', random_state=42, verbosity=-1, num_leaves=2, # 1 本の木には最大何枚の葉 (最低でも2枚にしないとエラー) n_estimators=2, # 最大ラウンド数 learning_rate=1, # 学習率 ) # フィット # https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html#lightgbm.LGBMClassifier.fit model.fit( x_train, y_train, eval_set=(x_test, y_test), eval_metric='multi_logloss', callbacks=[ lgb.early_stopping(stopping_rounds=3), # 3 ラウンド連続で損失が改善しなかったら停止 lgb.log_evaluation(1), ] ) return model def report(y_true, y_pred, target_names): # 分類結果レポート # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html d = classification_report(y_true, y_pred, target_names=target_names, output_dict=True) print('accuracy:', d['accuracy']) del d['accuracy'] print(pd.DataFrame(d).T) # accuracy 以外の値は同じキーを持つ辞書 if __name__ == '__main__': # アイリスデータセットのロード # https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html data = load_iris(as_frame=True) # データフレームとして取得 feat_names = ['sl', 'sw', 'pl', 'pw'] # 特徴量名を短縮 data.data.columns = feat_names # トレインテストスプリット (40% をテストデータに) # https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.4, random_state=42) assert x_train.shape == (90, 4) assert x_test.shape == (60, 4) # モデル訓練・予測・レポート・保存 model = train_model(x_train, y_train, x_test, y_test) y_pred = model.predict(x_test) report(y_test, y_pred, data.target_names) model.booster_.save_model('iris_lgbm.txt') # テストデータの予測の詳細および損失の検算 df = x_test.copy() df['true'] = y_test.map(lambda x: data.target_names[x]).values df_proba = pd.DataFrame(model.predict_proba(x_test)) df_proba.columns = [f'q({target_name})' for target_name in data.target_names] df = pd.concat([df.reset_index(drop=True), df_proba.reset_index(drop=True)], axis=1) df['q(true)'] = df.apply(lambda row: row['q({})'.format(row['true'])], axis=1) df['log(q(true))'] = df['q(true)'].map(np.log) print('multi_logloss:', -df['log(q(true))'].mean()) # 損失の検算 print(df.head())
[1] valid_0's multi_logloss: 0.347051 Training until validation scores don't improve for 3 rounds [2] valid_0's multi_logloss: 0.157009 Did not meet early stopping. Best iteration is: [2] valid_0's multi_logloss: 0.157009 accuracy: 0.9333333333333333 precision recall f1-score support setosa 1.000000 0.956522 0.977778 23.0 versicolor 0.894737 0.894737 0.894737 19.0 virginica 0.894737 0.944444 0.918919 18.0 macro avg 0.929825 0.931901 0.930478 60.0 weighted avg 0.935088 0.933333 0.933824 60.0 multi_logloss: 0.15700909674521454 sl sw pl pw true q(setosa) q(versicolor) q(virginica) q(true) log(q(true)) 0 6.1 2.8 4.7 1.2 versicolor 0.058272 0.869997 0.071730 0.869997 -0.139265 1 5.7 3.8 1.7 0.3 setosa 0.962649 0.027779 0.009572 0.962649 -0.038066 2 7.7 2.6 6.9 2.3 virginica 0.014405 0.051828 0.933767 0.933767 -0.068528 3 6.0 2.9 4.5 1.5 versicolor 0.058272 0.869997 0.071730 0.869997 -0.139265 4 6.8 2.8 4.8 1.4 versicolor 0.058272 0.869997 0.071730 0.869997 -0.139265
正解率が 93.3 % なので、テストデータ 60 サンプルのうち 56 サンプルを正しい品種と予測していることが窺えますね。訓練終了時の評価損失が 0.157009 とロギングされていますが、これはテストデータの各サンプルの「予測分布に対する真分布の交差エントロピー」の平均、つまり、以下に等しいことが検算できました。
model.booster_.save_model('iris_lgbm.txt') で出力された iris_lgbm.txt をみてみました。ここに出力されていた Tree=0 の情報は以下でした。この Tree=0 の分岐は「花弁の長さが 1.8cm 未満か」というものになっています。確認すると、訓練データのうちこれを満たす 26 サンプルは全てセトーサであり、27 サンプルのセトーサのうちほとんどを分離できるので、これはよい分岐なのでしょう。実際この分岐の「ゲイン」は split_gain=56.875 と何やら大きい値となっています。LightGBM のドキュメントのパラメータチューニングのページによると、「ゲイン」とはその分岐を追加することで損失を減じられる量のことであるそうですね。LightGBM は「ゲイン」が最大になる分岐 (どの特徴量のどの閾値) を選択するのだと。
Tree=0 num_leaves=2 num_cat=0 split_feature=2 # 花弁長 split_gain=56.875 threshold=1.8 # 閾値 decision_type=2 left_child=-1 right_child=-2 leaf_value=1.0182493968717188 -2.1067506267809168 leaf_weight=8.1899999380111712 20.159999847412109 leaf_count=26 64 internal_value=-1.20397 internal_weight=28.35 internal_count=90 is_linear=0 shrinkage=1
print(- 27 * np.log(27 / 90) - 31 * np.log(31 / 90) - 32 * np.log(32 / 90)) # 98.638 print(- 1 * np.log(1 / 64) - 31 * np.log(31 / 64) - 32 * np.log(32 / 64)) # 48.811
しかし、これで減らせる交差エントロピーは 50 くらいです。「ゲイン」56.875 と近い気もしますが、そうはいっても結構誤差があります。それも、理論上減じ得る値より「ゲイン」の方が大きいのです。
わからない点は他にもあります。出力された iris_lgbm.txt 中に出てくる 6 箇所の split_gain を足し上げると 140 ちょっとになりますが (訓練後のモデルに model.booster_.feature_importance(importance_type='gain') を適用することによって出せる特徴量ごとのゲインの和もこの値ですね)、元々の損失が 98.638 なので 140 も減らすことはできないはずです。損失が 0 になった時点で完璧な予測が達成されますから。
……そういうわけで不明点をまとめますと、なぜ分岐たちが主張する「ゲイン」は、彼らが実際に減らすであろう損失より大きいのでしょうか? 分岐たちは自らが如何に優秀な分岐であるかを盛っているのですか? 決定木の分岐に採用されるのも人間の就職活動と同じなのですか? さすれば私は LightGBM さんに彼らの欺瞞を伝えなければ……。
model.predict_proba(x_train) で予測確率を出力)。| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | 0.012 | 0.185 | 0.803 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | 0.252 | 0.691 | 0.057 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | 0.963 | 0.028 | 0.010 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | 0.963 | 0.028 | 0.010 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | 0.058 | 0.870 | 0.072 |
| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | s(setosa) | s(versicolor) | s(virginica) | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | -1.204 | -1.066 | -1.034 | 0.3 | 0.344 | 0.356 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | -1.204 | -1.066 | -1.034 | 0.3 | 0.344 | 0.356 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | -1.204 | -1.066 | -1.034 | 0.3 | 0.344 | 0.356 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | -1.204 | -1.066 | -1.034 | 0.3 | 0.344 | 0.356 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | -1.204 | -1.066 | -1.034 | 0.3 | 0.344 | 0.356 |
assert np.isclose(np.log(27 / 90), -1.204, atol=1e-03) # セトーサの存在比の対数 assert np.isclose(np.log(31 / 90), -1.066, atol=1e-03) # バーシカラーの存在比の対数 assert np.isclose(np.log(32 / 90), -1.034, atol=1e-03) # バージニカの存在比の対数
そうか、これから木によって訓練サンプルを分岐させて、「この葉に来たらセトーサらしさのスコアを上げよう」「この葉に来たらバーシカラーらしさのスコアを下げよう」などといったスコア調整をしていくわけですね! ……しかし、その分岐位置をどうやって見出すんです? 先ほどの副部長の言によりますと、地道にあらゆる分岐の損失減少幅を計算してみることはしないのですよね?
Tree=0 の split_gain=56.875 が再現できます。
q_setosa = 27 / 90 # 全体の 1, 2 回偏微分値の和 g = 27 * (- 1 + q_setosa) + 63 * q_setosa # 0 h = 90 * (3 / 2) * q_setosa * (1 - q_setosa) # 28.35 # 左の葉の 1, 2 回偏微分値の和 (花弁の長さが 1.8cm 未満の 26 サンプル) (全てセトーサ) g_L = 26 * (- 1 + q_setosa) # -18.2 h_L = 26 * (3 / 2) * q_setosa * (1 - q_setosa) # 8.19 # 右の葉の 1, 2 回偏微分値の和 (花弁の長さが 1.8cm 以上の 64 サンプル) (セトーサ 1 サンプルとその他) g_R = 1 * (- 1 + q_setosa) + 63 * q_setosa # 18.2 h_R = 64 * (3 / 2) * q_setosa * (1 - q_setosa) # 20.16 # 「花弁の長さが 1.8cm 未満か」という分岐のゲイン gain = g_L ** 2 / h_L + g_R ** 2 / h_R - g ** 2 / h print(gain) # 56.875
それに、この左右の葉の 2 回偏微分値 h_L, h_R は leaf_weight=8.1899999380111712 20.159999847412109 に一致していますね。
leaf_value=1.0182493968717188 -2.1067506267809168 に一致しているよ。第 2 ラウンド以降の leaf_value は「スコア変化幅」だけになっているけど、第 1 ラウンドだけは「初期値 + スコア変化幅」になっているみたい。勾配ブースティング決定木の記述方法としてこうするのが一般的なのかな?
print(- 1.204 - g_L / h_L) # 1.018 print(- 1.204 - g_R / h_R) # -2.107
ともあれ、実際に「セトーサらしさのスコア」を更新しよう。Tree=0 の分岐の結果、以下のようになる。
| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | s(setosa) | s(versicolor) | s(virginica) | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | -2.107 | -1.066 | -1.034 | 0.148 | 0.419 | 0.433 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | -2.107 | -1.066 | -1.034 | 0.148 | 0.419 | 0.433 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | 1.018 | -1.066 | -1.034 | 0.798 | 0.099 | 0.103 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | 1.018 | -1.066 | -1.034 | 0.798 | 0.099 | 0.103 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | -2.107 | -1.066 | -1.034 | 0.148 | 0.419 | 0.433 |
Tree=0 構築時点でこの値を計算するという意味ではないよ。Tree=1 では「バーシカラーらしさのスコア」を更新するよ。面白いことに選ばれた分岐の特徴量・閾値が Tree=0 と全く同じだね。 Tree=0 では「花弁の長さが 1.8cm 未満ならセトーサらしい」が見出され、Tree=1 では「花弁の長さが 1.8cm 未満ならバーシカラーらしくない」が見出されたわけだ。この結果、以下になる。| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | s(setosa) | s(versicolor) | s(virginica) | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | -2.107 | -0.653 | -1.034 | 0.122 | 0.522 | 0.356 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | -2.107 | -0.653 | -1.034 | 0.122 | 0.522 | 0.356 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | 1.018 | -2.083 | -1.034 | 0.852 | 0.038 | 0.109 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | 1.018 | -2.083 | -1.034 | 0.852 | 0.038 | 0.109 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | -2.107 | -0.653 | -1.034 | 0.122 | 0.522 | 0.356 |
Tree=1 の構築にあたって、Tree=0 による「セトーサらしさのスコア」の更新に伴う予測確率の更新は「反映されていない」よ。もし反映されていたらすべてのバーシカラーサンプルはバーシカラー予測確率が 0.419 になるけど、それだと右の葉の leaf_weight が 23.37 になるはずだから。ラウンドごとに各ラベルらしさのスコア更新が独立に走るんだから、Tree=0 と Tree=1 と Tree=2 のゲインを足せないのがわかるよね。3 人がバトンタッチしながら下山しているんじゃなくて、3 人同じ地点から下山して「自分は何 m 下りた」といっている状態だからね。▼ Tree=1 の情報と検算 (クリックして展開)
Tree=1 num_leaves=2 num_cat=0 split_feature=2 split_gain=12.8072 threshold=1.8 decision_type=2 left_child=-1 right_child=-2 leaf_value=-2.0827716810239876 -0.65268688567404975 leaf_weight=8.8062959909439105 21.677036285400391 leaf_count=26 64 internal_value=-1.06582 internal_weight=30.4833 internal_count=90 is_linear=0 shrinkage=1
q_versi = 31 / 90 g = 31 * (- 1 + q_versi) + 59 * q_versi # 0 g_L = 26 * q_versi # 8.9556 g_R = 31 * (- 1 + q_versi) + 33 * q_versi # -8.9556 h = 90 * (3 / 2) * q_versi * (1 - q_versi) # 30.483 h_L = 26 * (3 / 2) * q_versi * (1 - q_versi) # 8.806 h_R = 64 * (3 / 2) * q_versi * (1 - q_versi) # 21.677 gain = g_L ** 2 / h_L + g_R ** 2 / h_R - g ** 2 / h print(gain) # 12.8072 print(- 1.066 - g_L / h_L) # -2.083 print(- 1.066 - g_R / h_R) # -0.653
続く
Tree=2 では「バージニカらしさのスコア」を更新するわけだけど、選ばれた分岐は「花弁の幅が 1.55cm 未満ならバージニカらしくない」だね。ここまででスコアは以下になるよ。| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | s(setosa) | s(versicolor) | s(virginica) | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | -2.107 | -0.653 | 0.576 | 0.050 | 0.215 | 0.735 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | -2.107 | -0.653 | -1.966 | 0.155 | 0.666 | 0.179 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | 1.018 | -2.083 | -1.966 | 0.913 | 0.041 | 0.046 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | 1.018 | -2.083 | -1.966 | 0.913 | 0.041 | 0.046 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | -2.107 | -0.653 | -1.966 | 0.155 | 0.666 | 0.179 |
model.booster_.feature_importance(importance_type='gain') とすると各特徴量が稼いだゲインをみせてくれますが、それを踏まえてみる必要がありそうですね。Tree=5 の予測確率は最終結果に等しくなっているね。Tree=3
| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | s(setosa) | s(versicolor) | s(virginica) | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | -2.850 | -0.653 | 0.576 | 0.025 | 0.221 | 0.755 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | -1.156 | -0.653 | -1.966 | 0.323 | 0.534 | 0.143 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | 1.969 | -2.083 | -1.966 | 0.964 | 0.017 | 0.019 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | 1.969 | -2.083 | -1.966 | 0.964 | 0.017 | 0.019 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | -2.850 | -0.653 | -1.966 | 0.081 | 0.725 | 0.195 |
| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | s(setosa) | s(versicolor) | s(virginica) | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | -2.850 | -0.147 | 0.576 | 0.021 | 0.320 | 0.659 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | -1.156 | -0.147 | -1.966 | 0.239 | 0.655 | 0.106 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | 1.969 | -1.577 | -1.966 | 0.954 | 0.028 | 0.019 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | 1.969 | -1.577 | -1.966 | 0.954 | 0.028 | 0.019 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | -2.850 | -0.147 | -1.966 | 0.054 | 0.814 | 0.132 |
| 萼長 | 萼幅 | 花長 | 花幅 | 正解ラベル | s(setosa) | s(versicolor) | s(virginica) | q(setosa) | q(versicolor) | q(virginica) | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 6.3 | 2.7 | 4.9 | 1.8 | virginica | -2.850 | -0.147 | 1.322 | 0.012 | 0.185 | 0.803 |
| 1 | 4.8 | 3.4 | 1.9 | 0.2 | setosa | -1.156 | -0.147 | -2.642 | 0.252 | 0.691 | 0.057 |
| 2 | 5.0 | 3.0 | 1.6 | 0.2 | setosa | 1.969 | -1.577 | -2.642 | 0.963 | 0.028 | 0.010 |
| 3 | 5.1 | 3.3 | 1.7 | 0.5 | setosa | 1.969 | -1.577 | -2.642 | 0.963 | 0.028 | 0.010 |
| 4 | 5.6 | 2.7 | 4.2 | 1.3 | versicolor | -2.850 | -0.147 | -2.642 | 0.058 | 0.870 | 0.072 |