雑記

参考文献

  1. GitHub - lucidrains/performer-pytorch: An implementation of Performer, a linear attention-based transformer, in Pytorch(2022年5月7日参照).


以下の確認スクリプトのとおり、performer_pytorch (v1.1.4) は次のような構造をしていることがわかる。
FastAttention がどう Fast なのかが肝心である。
なお、以下の確認スクリプトとは関係ないが、performer_pytorch.reversible.ReversibleSequence が reformer_pytorch にある同名のクラスと同じものかどうかは、まだ確認できていない。

import torch
import torch.nn as nn
from performer_pytorch import Performer
from performer_pytorch.reversible import SequentialSequence
from performer_pytorch.performer_pytorch \
    import PreLayerNorm, SelfAttention, Chunk, FastAttention, FeedForward

# Instantiate a small causal Performer to inspect: dim=8, depth=3,
# heads=4, dim_head=6.
model = Performer(dim=8, depth=3, heads=4, dim_head=6, causal=True)

# Smoke test: feed a random tensor of shape (5, 128, 8) through the model.
# NOTE(review): presumably the axes are (batch, sequence, dim) -- confirm
# against the performer_pytorch documentation.
x = torch.randn(5, 128, 8)
y = model(x)

# The forward pass is expected to return a tensor of the same shape as x.
assert tuple(y.shape) == (5, 128, 8)

# Top-level container: model.net must be exactly a SequentialSequence whose
# `layers` attribute is an nn.ModuleList with 3 entries (the same number as
# the depth=3 passed to Performer).
net = model.net
assert type(net) is SequentialSequence

layers = net.layers
assert type(layers) is nn.ModuleList
assert len(layers) == 3

# Walk the 3 layers and check the exact type of every sub-module.
# `type(...) is ...` is used on purpose: subclasses must not pass.
for i in range(3):
    # Each layer is an nn.ModuleList holding exactly two sub-blocks.
    layer = model.net.layers[i]
    assert type(layer) is nn.ModuleList
    assert len(layer) == 2

    # --- Sub-block 0: PreLayerNorm wrapping SelfAttention ---
    attn_block = layer[0]
    assert type(attn_block) is PreLayerNorm
    assert type(attn_block.norm) is nn.LayerNorm

    self_attn = attn_block.fn
    assert type(self_attn) is SelfAttention

    fast_attn = self_attn.fast_attention
    assert type(fast_attn) is FastAttention
    assert type(fast_attn.kernel_fn) is nn.ReLU

    # The q/k/v and output projections are all plain nn.Linear modules.
    # Looked up by name one at a time so each attribute is only touched
    # right before its own assertion.
    for proj_name in ("to_q", "to_k", "to_v", "to_out"):
        assert type(getattr(self_attn, proj_name)) is nn.Linear
    assert type(self_attn.dropout) is nn.Dropout

    # --- Sub-block 1: PreLayerNorm wrapping Chunk wrapping FeedForward ---
    ff_block = layer[1]
    assert type(ff_block) is PreLayerNorm
    assert type(ff_block.norm) is nn.LayerNorm

    chunk = ff_block.fn
    assert type(chunk) is Chunk

    feed_forward = chunk.fn
    assert type(feed_forward) is FeedForward
    assert type(feed_forward.w1) is nn.Linear
    assert type(feed_forward.act) is nn.GELU
    assert type(feed_forward.dropout) is nn.Dropout
    assert type(feed_forward.w2) is nn.Linear