雑記: Reformer のモデル構造

参考文献

  1. GitHub - lucidrains/reformer-pytorch: Reformer, the efficient Transformer, in Pytorch(2022年4月28日参照).


reformer_pytorch(v1.4.4)は以下のような構造をしているのがわかる。
肝心なのは LSHSelfAttention が何をしているかである。通常の Transformer と違い、Q への写像と K への写像が同じ nn.Linear で共有されており(shared-QK)、LSHAttention という方式でアテンションを計算する(では、その LSHAttention とは何なのか、というのが次の問題である)。

import torch
import torch.nn as nn
from reformer_pytorch import Reformer, LSHSelfAttention
from reformer_pytorch.reversible import ReversibleSequence
from reformer_pytorch.reversible import ReversibleBlock, IrreversibleBlock
from reformer_pytorch.reversible import Deterministic
from reformer_pytorch.reformer_pytorch import PreNorm, Chunk

# Sizes used for the model and the dummy input, named once so the constructor
# arguments, the input tensor and the expected output shape stay in sync.
BATCH_SIZE = 5
# NOTE(review): the reformer_pytorch README requires the sequence length to be
# divisible by bucket_size * 2 (bucket_size defaults to 64); 128 satisfies that
# for the defaults — confirm against the installed v1.4.4.
SEQ_LEN = 128
DIM = 8
DEPTH = 3

model = Reformer(
    dim=DIM,
    depth=DEPTH,
    heads=4,
    lsh_dropout=0.0,
    causal=True,
)

# Forward pass on random input: the output keeps the (batch, seq, dim) shape
# of the input.
x = torch.randn(BATCH_SIZE, SEQ_LEN, DIM)
y = model(x)
assert tuple(y.shape) == (BATCH_SIZE, SEQ_LEN, DIM)

# The model body is a ReversibleSequence exposing two nn.ModuleList attributes:
# `blocks` and `irrev_blocks`. Exact-type checks (not isinstance) are
# deliberate here: the point is to pin down the concrete classes used.
assert type(model.layers) is ReversibleSequence
assert all(
    type(getattr(model.layers, attr_name)) is nn.ModuleList
    for attr_name in ("blocks", "irrev_blocks")
)

# One ReversibleBlock per layer (the model was built with depth=3).
# Exact-type checks (not isinstance) are deliberate: this pins the concrete
# module classes that make up each block.
assert len(model.layers.blocks) == 3
for rev_block in model.layers.blocks:
    assert type(rev_block) is ReversibleBlock

    # f half: Deterministic -> PreNorm(LayerNorm, LSHSelfAttention).
    # NOTE(review): Deterministic presumably records RNG state so the block
    # can be recomputed in the reversible backward pass — confirm in
    # reformer_pytorch.reversible.
    assert type(rev_block.f) is Deterministic
    assert type(rev_block.f.net) is PreNorm
    assert type(rev_block.f.net.norm) is nn.LayerNorm
    assert type(rev_block.f.net.fn) is LSHSelfAttention

    # g half: Deterministic -> PreNorm(LayerNorm, Chunk).
    # NOTE(review): Chunk presumably wraps the feed-forward sublayer (not
    # visible here) — confirm.
    assert type(rev_block.g) is Deterministic
    assert type(rev_block.g.net) is PreNorm
    assert type(rev_block.g.net.norm) is nn.LayerNorm
    assert type(rev_block.g.net.fn) is Chunk

# One IrreversibleBlock per layer as well. Unlike the reversible blocks, f and
# g here are PreNorm modules directly, with no Deterministic wrapper.
# NOTE(review): these look like the same PreNorm instances reused from
# model.layers.blocks[i].f.net / .g.net — verify with `is` if that matters.
assert len(model.layers.irrev_blocks) == 3
for irrev_block in model.layers.irrev_blocks:
    assert type(irrev_block) is IrreversibleBlock

    # f: PreNorm(LayerNorm, LSHSelfAttention).
    assert type(irrev_block.f) is PreNorm
    assert type(irrev_block.f.norm) is nn.LayerNorm
    assert type(irrev_block.f.fn) is LSHSelfAttention

    # g: PreNorm(LayerNorm, Chunk).
    assert type(irrev_block.g) is PreNorm
    assert type(irrev_block.g.norm) is nn.LayerNorm
    assert type(irrev_block.g.fn) is Chunk