雑記: Flax でのモデル生成と適用

参考文献

  1. Flax 2 ("Linen") - Colaboratory(2022年4月27日参照).
  2. JAX/Flaxを使ってMNISTを学習させてみる | TC3株式会社(2022年4月27日参照).


Flax でニューラルモデルにデータを流すのは以下のようにやることになるということです。
JAX では乱数を生成する度にサブキーをちぎり出すので乱数を生成する度にサブキーをちぎり出しています。
「初期パラメータを得ます」というのが聞き慣れない手順だと思います。このときに入力も渡します。最初に出力次元数しか渡さなかったのでそら入力次元数知らんやろなあというのはそうですが、ほならなんで最初に入力次元数を宙ぶらりんにしておくのかわかりません。そのうちわかるんだろうと思います。なお、最後の次元数が 5 の多次元配列であれば何でもよいようにみえます。

import jax
import jax.numpy as jnp
from flax import linen


key = jax.random.PRNGKey(26)  # 乱数のキーを入手する

# モデル(出力次元数が 2 の全結合層)を生成する
model_flax = linen.Dense(features=2)

# ダミー入力を生成する
key, sub_key = jax.random.split(key, 2)
x = jax.random.uniform(sub_key, (4, 5))  # バッチサイズ 4、入力次元数 5

# ダミー入力を渡すことでモデルの初期パラメータを得る
key, sub_key = jax.random.split(key, 2)
init_variables = model_flax.init(sub_key, x)
# init_variables = model_flax.init(sub_key, jnp.ones([4, 5]))  # これでも結果同じ
# init_variables = model_flax.init(sub_key, jnp.ones([8, 5]))  # これでも結果同じ
# init_variables = model_flax.init(sub_key, jnp.ones([3, 4, 5]))  # 何ならこれでも結果同じ

# モデルを適用する(パラメータ及び入力を渡す)
y = model_flax.apply(init_variables, x)
print(y)
[[-0.46102825  0.67549485]
 [-0.39238137 -0.3713517 ]
 [ 0.07502306 -0.2957006 ]
 [-0.54493564  0.1820691 ]]