参考文献
- GitHub - lucidrains/reformer-pytorch: Reformer, the efficient Transformer, in Pytorch(2022年4月28日参照).
reformer_pytorch(v1.4.4)は以下のような構造をしているのがわかる。
LSHSelfAttention が何なのかというのが肝心であるが、通常の Transformer と違い Q への写像と K への写像が同じ nn.Linear で共有されており(shared-QK)、LSHAttention というアテンションの仕方をする(では LSHAttention とは具体的に何なのか)。
import torch
import torch.nn as nn
from reformer_pytorch import Reformer, LSHSelfAttention
from reformer_pytorch.reversible import ReversibleSequence
from reformer_pytorch.reversible import ReversibleBlock, IrreversibleBlock
from reformer_pytorch.reversible import Deterministic
from reformer_pytorch.reformer_pytorch import PreNorm, Chunk

# Build a small 3-layer Reformer and verify the internal module structure
# of reformer_pytorch (checked against v1.4.4).
model = Reformer(
    dim=8,
    depth=3,
    heads=4,
    lsh_dropout=0.0,
    causal=True,
)

# The forward pass preserves the (batch, seq, dim) shape.
x = torch.randn(5, 128, 8)
y = model(x)
assert list(y.shape) == [5, 128, 8]

# The layer stack is a ReversibleSequence that keeps both a reversible and
# an irreversible variant of each block in parallel ModuleLists.
layers = model.layers
assert type(layers) is ReversibleSequence
assert type(layers.blocks) is nn.ModuleList
assert type(layers.irrev_blocks) is nn.ModuleList

# Each reversible block wraps f = PreNorm(LSHSelfAttention) and
# g = PreNorm(Chunk), each inside a Deterministic container.
assert len(layers.blocks) == 3
for block in layers.blocks:
    assert type(block) is ReversibleBlock
    assert type(block.f) is Deterministic
    assert type(block.f.net) is PreNorm
    assert type(block.f.net.norm) is nn.LayerNorm
    assert type(block.f.net.fn) is LSHSelfAttention
    assert type(block.g) is Deterministic
    assert type(block.g.net) is PreNorm
    assert type(block.g.net.norm) is nn.LayerNorm
    assert type(block.g.net.fn) is Chunk

# The irreversible counterparts hold the same PreNorm-wrapped modules,
# just without the Deterministic wrapper.
assert len(layers.irrev_blocks) == 3
for block in layers.irrev_blocks:
    assert type(block) is IrreversibleBlock
    assert type(block.f) is PreNorm
    assert type(block.f.norm) is nn.LayerNorm
    assert type(block.f.fn) is LSHSelfAttention
    assert type(block.g) is PreNorm
    assert type(block.g.norm) is nn.LayerNorm
    assert type(block.g.fn) is Chunk