雑記

この記事シリーズは目的が未達成です。
お気付きの点がありましたらご指摘いただけますと幸いです。

💜

f:id:cookie-box:20211229152010p:plain:w70
  • 前回、最新の Trax (v1.4.1) を導入した上でリポジトリの「Reformer で機械翻訳」なるノートブック(以下)と同じモデルを手元で生成しようとしたところ、バージョンの食い違い(ノートブックは v1.2.3)でハイパーパラメータと重みの設定ファイル(config.gin, model.pkl)がロードできませんでした。

    https://github.com/google/trax/blob/v1.4.1/trax/models/reformer/machine_translation.ipynb

  • しかしよくみると、config.gin とはただのテキストファイルなのですね? であれば、エラーメッセージをみて徐々に config.gin を修正していけば……model.pkl も ModuleNotFoundError: No module named 'trax.history' などとエラーを吐いてきますね。もう Trax のバージョンを下げましょう。この Pipfile で Trax だけ v1.2.3 に下げる分には何のエラーも起きなさそう……ではないですね。pipenv install には失敗しないようなのですが、実際に import trax してみるとエラーが発生します。
  • >>> import trax
    site-packages/trax/layers/__init__.py", line 43, in layer_configure
        return gin.external_configurable(*args, **kwargs)
    TypeError: external_configurable() got an unexpected keyword argument 'blacklist'
    

  • ここでエラーを吐いているのは gin-config の関数なのですから、gin-config のバージョンも下げるべきということなのでしょうか。Pipfile.lock をみると手元の gin-config のバージョンは 0.5.0 です。他方、Trax が v1.2.3 になったのは 2020 年 2 月のようです(参考)。この頃の gin-config のバージョンは…… 0.4.0 でしょうか(参考)。Pipfile にそう指定してみましょう。今度は上手く……いきませんね。今後は gym さんのせいとのことなので gym さんを 0.17.0 に落としましょう。これで import trax が通り、前回の記事の設定ファイルの読み込みが実行できますね。やたら WARNING メッセージが出ますが……。
  • WARNING:root:Argument blacklist is deprecated. Please use denylist.
    WARNING:root:Argument blacklist is deprecated. Please use denylist.
    WARNING:root:Argument blacklist is deprecated. Please use denylist.
    ...
    

  • それではさっそく Reformer のインスタンスを生成しましょう。先のノートブックではビームサーチするデコーダとして生成するので同様にしましょう……Jax がエラーになりますね。
  • import trax
    from trax.models.beam_search import Search
    from tensorflow.compat.v1.io.gfile import GFile
    import pickle
    import gin
    
    gin.parse_config_file('./config.gin')
    with GFile('./model.pkl', 'rb') as f:
        model_weights = pickle.load(f)['weights']
    
    beam_decoder = Search(
        trax.models.Reformer, model_weights,
        beam_size=4, alpha=0.6, eos_id=1, max_decode_len=146,
    )
    
    AttributeError: module 'jax' has no attribute 'api'
    

  • 今度は Jax のバージョンでしょうか? よくみるとノートブック内に以下のようにありましたね。しかしこの版の jaxlib では cuda11 向けのビルドがありませんね。かといって jax だけ下げてもエラーになりますし、cuda102 向けのビルドをインストールしてもエラーになりますね。
  • !gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .
    !gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .

  • ちなみに jax.api は jax 0.2.21 (Sept 23, 2021) で撤去されたようですね。

    Change log — JAX documentation

    撤去といっても、それまでの jax.api.* は jax.* として使用できるようですから、Trax v1.2.3 のコードを手元にチェックアウトして jax.api.* を jax.* に編集すればいいのでしょうか。