お気付きの点がありましたらご指摘いただけますと幸いです。
💜
- Numpy の疑似乱数は、同じシードの値をセットして順に呼び出していけば、同じ値の列が生成されることが保証されるとのことです(もちろん、前回は正規分布にしたがう乱数を生成したが、今回は二項分布にしたがう乱数を生成したでは同じ値にならないですよ!)。これは実験を再現するのにとても便利なのではないのでしょうか。一体何の不満が?
- 並列化……確かに、メイン関数から関数 A と関数 B を並列に呼び出すときは厄介ですね。仮にどちらの関数内でも 10 個の疑似乱数を生成して利用するとき、どちらの関数が何番目の乱数生成を踏むか保証されません。Numpy 方式ではこのような場合に再現性のあるコードをかけないです。乱数生成がグローバルな状態に基づくということなので、関数 A と関数 B のそれぞれで改めてシードをセットするなどしても駄目でしょう。
- だから Jax では状態を明示的に管理し、それが jax.random.PRNGKey(seed) で生成されるキーなのですね。 Numpy ではこうですよね。
- ユーザは最初に適当なシードを与えることでグローバルな状態を初期化する。
- ユーザが疑似乱数生成を呼び出すたびにグローバルな状態が勝手に変化する。
- ユーザは最初に適当なシードを与えることでキーを得る。
- ユーザは疑似乱数生成を呼び出すときキーを渡さなければならない。同じキーなら同じ値が生成される。
- これだと複数回乱数を生成したいときどうするのかとなるが、キーは split して好きな個数に分裂させることができる(分裂させたら古いキーは捨てる)。
- なので基本的には疑似乱数を生成する直前に key, subkey = random.split(key) して subkey を使い捨てていくことになるはずである。
- 並列処理時には、例えば処理の途中箇所を4枚の GPU で処理するとき、その箇所に入る手前で切り分けたキーをさらに 4 つに切り分けてそれぞれの GPU に与えることになる。