参考文献
- GitHub - google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
- GitHub - google/trax: Trax — Deep Learning with Clear Code and Speed
CUDA 11.6 + cuDNN v8.2.0 on Ubuntu20.04 on WSL2 への JAX の導入は以下のコマンドでできます [1]。
$ pip install --upgrade pip $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
でも pipenv で管理したいです。上の URL から自分の環境に入ったパッケージを選んで直接指定すれば pipenv でもインストールできます。
$ pipenv install "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.2+cuda11.cudnn82-cp38-none-manylinux2010_x86_64.whl" $ pipenv install jax==0.3.4
そもそも今回 JAX を導入したかった理由が Trax [2] ですが、JAX があれば Trax は以下で導入できます。
$ pipenv install trax==1.4.1
この環境で jax.numpy と trax.fastmath が動くことだけ確認したコードが以下です。ただ DeprecationWarning が出ます → JAX のリポジトリでもそうしていたので pytest.ini で抑制しました。
GitHub - CookieBox26/ML-on-WSL