transformers で学習済みの BERT モデルから固有表現抽出用のモデルインスタンスをつくるまでだけです。
→ GitHub に移行しました。GitHub - CookieBox26/ML: machine learning
コード
import torch from transformers import ( BertConfig, BertTokenizer, BertForTokenClassification, ) def main(): # 各トークンを以下の13クラスのいずれかに分類するような固有表現抽出をしたい. labels = [ 'B-corporation', 'B-creative-work', 'B-group', 'B-location', 'B-person', 'B-product', 'I-corporation', 'I-creative-work', 'I-group', 'I-location', 'I-person', 'I-product', 'O' ] id2label = {i: label for i, label in enumerate(labels)} label2id = {label: i for i, label in enumerate(labels)} # 利用する学習済みBERTモデルの名前を指定する. model_name = 'bert-large-cased' # 学習済みモデルに対応したトークナイザを生成する. tokenizer = BertTokenizer.from_pretrained( pretrained_model_name_or_path=model_name, ) # 学習済みモデルから各トークン分類用モデルのインスタンスを生成する. # 設定する内容にもよるが必ずしも設定オブジェクトを生成して渡す必要はない. model = BertForTokenClassification.from_pretrained( pretrained_model_name_or_path=model_name, id2label=id2label, # 各トークンに対する出力を13次元にしたいのでこれを渡す. ) # 一部の重みが初期化されていませんよという警告が出るが(クラス分類する層が # 初期化されていないのは当然)面倒なので無視する. # print(model) # 24層あるのでプリントすると長い. print('◆ 適当な文章をID列にしてみる.') sentence = 'The Empire State Building officially opened on May 1, 1931.' # BERT に文章を流すとき文頭に特殊トークン [CLS] 、 # 文末に特殊トークン [SEP] が想定されている. # tokenizer.encode() でID列にすると勝手に付加されている. print('◇') ids = tokenizer.encode(sentence) for id_ in ids: token = tokenizer.convert_ids_to_tokens(id_) print(str(id_).ljust(5), tokenizer.convert_ids_to_tokens(id_)) # 先にトークン列が手元にある場合は特殊トークンを明示的に付加する. print('◇') tokens = tokenizer.tokenize(sentence) tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token] for token in tokens: id_ = tokenizer.convert_tokens_to_ids(token) print(str(id_).ljust(5), tokenizer.convert_ids_to_tokens(id_)) print('◆ モデルに流してみる.→ 14トークン×13クラスの予測結果になっている(サイズが).') inputs = torch.tensor([tokenizer.encode(sentence)]) # ID列をテンソル化して渡す. outputs = model(inputs) print(outputs[0].size())
出力
# ここで一部の重みが初期化されていませんよという警告が出るが気にしないことにする. ◆ 適当な文章をID列にしてみる. ◇ 101 [CLS] 1109 The 2813 Empire 1426 State 4334 Building 3184 officially 1533 opened 1113 on 1318 May 122 1 117 , 3916 1931 119 . 102 [SEP] ◇ 101 [CLS] 1109 The 2813 Empire 1426 State 4334 Building 3184 officially 1533 opened 1113 on 1318 May 122 1 117 , 3916 1931 119 . 102 [SEP] ◆ モデルに流してみる.→ 14トークン×13クラスの予測結果になっている(サイズが). torch.Size([1, 14, 13])
Python環境
[[source]] name = "pypi" url = "https://pypi.org/simple" verify_ssl = true [packages] torch = "==1.4.0" transformers = "==3.1.0" [requires] python_version = "3.7.0"