雑記: WSL2 + CUDA 11.6 に Jax を Pipenv で導入するだけ

参考文献

  1. GitHub - google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
  2. GitHub - google/trax: Trax — Deep Learning with Clear Code and Speed


CUDA 11.6 + cuDNN v8.2.0 on Ubuntu20.04 on WSL2 への JAX の導入は以下のコマンドでできます [1]。

$ pip install --upgrade pip
$ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

でも pipenv で管理したいです。上の URL から自分の環境に入ったパッケージを選んで直接指定すれば pipenv でもインストールできます。

$ pipenv install "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.2+cuda11.cudnn82-cp38-none-manylinux2010_x86_64.whl"
$ pipenv install jax==0.3.4

そもそも今回 JAX を導入したかった理由が Trax [2] ですが、JAX があれば Trax は以下で導入できます。

$ pipenv install trax==1.4.1


この環境で jax.numpy と trax.fastmath が動くことだけ確認したコードが以下です。ただ DeprecationWarning が出ます → JAX のリポジトリでもそうしていたので pytest.ini で抑制しました。
GitHub - CookieBox26/ML-on-WSL