torch.nn.Transformer を使用したことがありません……ドキュメントによると「BERT が構築できます」とのことなので transformers.BertModel と同様のモデルであると思うのですが、百聞は一見にしかず、torch.nn.Transformer のソースコードをみてみましょう。
pytorch/transformer.py at v1.10.1 · pytorch/pytorch · GitHub
ソースコードをみると、Transformer クラスは encoder と decoder を内包し、encoder の実装は TransformerEncoder クラスであり、その構造は以下のような感じですね。入力配列にエンコーダ層を繰り返し適用した後、最後の仕上げに LayerNorm します(省略可能ですが)。TransformerEncoder クラス
エンコーダ層を 6 つかいてみたのは Transformer クラスのデフォルト設定値が 6 だったのでかいてみただけであり、無論変更可能です。リザバートランスフォーマーではこの繰り返されるエンコーダ層を 1 つおきにリザバーにするなりするのでしょう。- layers メンバ: ModuleList
- 0: TransformerEncoderLayer
- 1: TransformerEncoderLayer
- 2: TransformerEncoderLayer
- 3: TransformerEncoderLayer
- 4: TransformerEncoderLayer
- 5: TransformerEncoderLayer
- norm メンバ: LayerNorm
ではそのエンコーダ層、TransformerEncoderLayer クラスの実装をみましょう。入力データへの適用順にニューラル層であるメンバを列挙すると以下でしょうか。
TransformerEncoderLayer クラス(2ブロックからなる)
前3つと後ろ5つのメンバを気持ち離してかいたのは前者が self-attention block、後者が feed forward block と名付けられているからです。全結合をもつメンバには重みの行列とバイアスのベクトルのサイズをメモしておきました(Transformer クラスをデフォルト設定値でインスタンス化した場合です)。関係ないですが、デフォルト設定値では構造は BERT と一致していませんね。bert-base-uncased でもモデル次元は 768 次元ですから。- self_attn メンバ: MultiheadAttention
- in_proj_weight:[1536, 512],in_proj_bias:[1536]
- out_proj.weight:[512, 512],out_proj.bias:[512]
- dropout1 メンバ: Dropout
- norm1 メンバ: LayerNorm(「このブロックの入力に直前の層まで適用したもの + このブロックの入力」に適用)
- linear1 メンバ: Linear
- weight:[2048, 512],bias:[2048]
- dropout メンバ: Dropout
- linear2 メンバ: Linear
- weight:[512, 2048],bias:[512]
- dropout2 メンバ: Dropout
- norm2 メンバ: LayerNorm(「このブロックの入力に直前の層まで適用したもの + このブロックの入力」に適用)
それで、エンコーダ層でモデルを最初に待ち受ける MultiheadAttention クラスはソースファイルが変わって以下にありますね。
pytorch/activation.py at v1.10.1 · pytorch/pytorch · GitHub
さらに、この層の forward() の実体は F.multi_head_attention_forward であり、以下のソースですね。pytorch/functional.py at v1.10.1 · pytorch/pytorch · GitHub
上のソースを参考に MultiheadAttention.forward() を自分で検算してみましょう。ML/test_multiheadattention.py at 5d2fe82a2bdcdbf932a2ea93b12c55a43e796220 · CookieBox26/ML · GitHub
MultiheadAttention.forward() はセルフアテンション適用後の単語列に加えてセルフアテンション行列も返却してくれるのでその値を検算しました。今回は 8 ヘッドあるのでセルフアテンション行列は本来 8 個あるはずなんですが、返却されるのは全てのヘッドの平均を取ったもののようですね。これを得て何につかうのでしょうか。文章を処理してどの単語からどの単語にアテンションがあるか可視化するなどは全ヘッドの平均でみて差し支えないと思いますが。