雑記: JAX の疑似乱数生成はメルセンヌ・ツイスタではない

参考文献

  1. Python + NumPy: 乱数生成のシード(RandomState)をちゃんと管理する - 物理の駅 by 現役研究者(2022年4月26日参照).
  2. Mersenne Twister - Wikipedia(2022年4月26日参照).
    • メルセンヌ・ツイスタの記事だが日本語版と違って図解と疑似コードがある。
  3. 🔪 JAX - The Sharp Bits 🔪 — JAX documentation(2022年4月26日参照).

関連記事



私は4月22日と23日に関連記事の jax.random.PRNGKey に関する記事をかいていた。それらの記事で参照していた JAX のドキュメントでは「NumPy の疑似乱数生成はグローバルな状態に基づく」とあったが、先ほど参考文献 [1] をみたら NumPy も状態をもつ疑似乱数生成器のインスタンスは生成できるようだ。JAX のドキュメントに騙された。

ただ、説明を簡潔にするためにいちいちかかなかったのかもしれない。また、疑似乱数生成器のインスタンスを生成できれば確かに処理の並列化はできるが、並列箇所でどのように並列数だけの疑似乱数生成器を用意して交代するかは自明ではないのでやはり NumPy と同じ疑似乱数生成方式を脱却しなければならない、ということになるとは思う。さらに、多次元ベクトルの各成分を同時に生成するのにも対応する必要があるらしい。そうしなければ SIMD ハードウェアでのベクトル化に支障があるらしい。SIMD ハードウェアのことを全く知らないが、NumPy 方式では 100 個並んだ箱を乱数で埋めるのに、「左隣の箱を埋めてからでないと埋められない」といったことになるわけだが、SIMD ではきっとそれが許容されないのだろう。

それで、JAX における疑似乱数生成は生成の度に「キーから新しいキーとサブキーを生成してサブキーを消費していく」というものだが、なので JAX の疑似乱数生成器はメルセンヌ・ツイスタ [2] ではない(キーがメルセンヌ・ツイスタの内部ベクトルにはみえないから)。調べると [3] にそうかいてあった。

JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.