feirseqを使って、BARTで日本語の文章要約モデルを学習する方法
この記事では、fairseqを使って、BARTで日本語の文章要約を行う手順について解説します。
目次
全体の手順は、以下のようになります。
fairseq(日本語版BARTの事前学習モデル)のインストール
事前学習モデルのダウンロード
Juman++ (2.0.0-rc3) のインストール
データセットの取得と前処理
ファインチューニング
要約実行
定量的評価
ここからはそれぞれについて解説していきたいと思います。
gitからcloneすることでインストールを行います。
pipのみでインストールすることもできますが、様々なバージョンが存在するため、gitから目的に合ったリポジトリをcloneすることを推奨します。
以下は、Google Colablatoryを使って実行した場合の手順です。試したい方は以下のコードを順番にコピーして実行してください。ローカル環境などを使う場合は、適時置き換えてください。
# fairseqインストール(gitから)
!git clone https://github.com/utanaka2000/fairseq.git
%cd fairseq
!git fetch origin
!git checkout japanese_bart_pretrained_model
!git branch
!pip install --editable ./文章要約
# その他必要なライブラリをインストール
!pip install zenhan sentencepiece tensorboard
# fairseq有効化
!echo $PYTHONPATH
import os
os.environ['PYTHONPATH'] = "/env/python"
os.environ['PYTHONPATH'] += ":/content/fairseq/"
!echo $PYTHONPATH
# インストールの確認
!pip show fairseq
2. 事前学習モデルのダウンロード
以下から使用するモデルを一つ選択します。ダウンロードされるフォルダには、事前学習済みBARTモデル(bart_model.pt)、センテンスピースモデル(sp.model)、および辞書(dict.txt)が含まれます。
# BART base v1.1 (1.3G)
http://lotus.kuee.kyoto-u.ac.jp/nl-resource/JapaneseBARTPretrainedModel/japanese_bart_base_1.1.tar.gz
# BART large v1.0 (3.6G)
http://lotus.kuee.kyoto-u.ac.jp/nl-resource/JapaneseBARTPretrainedModel/japanese_bart_large_1.0.tar.gz
# BART base v2.0 (1.3G)
http://lotus.kuee.kyoto-u.ac.jp/nl-resource/JapaneseBARTPretrainedModel/japanese_bart_base_2.0.tar.gz
# BART large v2.0 (3.7G)
http://lotus.kuee.kyoto-u.ac.jp/nl-resource/JapaneseBARTPretrainedModel/japanese_bart_large_2.0.tar.gz
BART base v1.1 (1.3G)のダウンロードとtar.gzファイルの解凍方法は以下になります。他のモデルを使用する場合はダウンロードするURLを置き換えてください。
%cd /content
!wget http://lotus.kuee.kyoto-u.ac.jp/nl-resource/JapaneseBARTPretrainedModel/japanese_bart_base_1.1.tar.gz
!tar -zxvf japanese_bart_base_1.1.tar.gz
3. Juman++ (2.0.0-rc3) のインストール
データセットの前処理時に、形態素解析ツールとして利用するJuman++のインストール方法について記載します。
まず、juman++のダウンロードと解凍をします。
# jumanpp-2.0.0-rc3 download
!wget https://github.com/ku-nlp/jumanpp/releases/download/v2.0.0-rc3/jumanpp-2.0.0-rc3.tar.xz
# unzip a file
!tar xvf jumanpp-2.0.0-rc3.tar.xz
次に、juman++は文字数の上限が4096バイトまでに設定されているため、以下の2つのファイルを変更します。
jumanpp-2.0.0-rc3/src/core/analysis/analyzer.h
16,17行目
size_t pageSize = 48 * 1024 * 1024;
size_t maxInputBytes = 48 * 1024;
jumanpp-2.0.0-rc3/src/core/input/stream_reader.h
27,28行目
u64 maxInputLength_ = 48 * 1024;
u64 maxCommentLength_ = 48 * 1024;
最後に、juman++をインストールします。
# build jumanpp
%cd jumanpp-2.0.0-rc3/
!mkdir buildした
!cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local
!make
# install jumanpp
!sudo make install
4. データセットの取得と前処理
データセットの取得
wikiHowデータセットの使い方 とlivedoorデータセットの使い方 を参考にデータセットを取得します。
前処理用ファイルの作成
取得したデータセットを利用して、以下の6つのファイルを作成します。データセットはsrcとtgtの二つに分けてtrain.src, train.tgtのように作成します。srcには要約前の文章、tgtには要約後の文章が入り、srcのN行目とtgtのN行目が一つのペアとなるようにします。
train_src.txt
train_tgt.txt
val_src.txt
val_tgt.txt
test_src.txt
test_tgt.txt
livedoorから取得した101,559件のデータの内、以下の条件のデータを抽出する場合のコードを紹介します。
要約前の文字数:1500字以下
要約後の文字数 / 要約前の文字数(圧縮率):5〜50%
import re
livedoor_df = pd.read_csv("livedoor_datasets.csv")
input_df = livedoor_df.copy()
input_df["tgt_length"] = input_df["tgt"].apply(lambda x: len(str(x)))
input_df["src_length"] = input_df["src"].apply(lambda x: len(str(x)))
input_df["rate"] = input_df["tgt_length"] / input_df["src_length"]
mask = input_df['src_length'] <= 1500
input_df = input_df[mask]
mask = input_df['rate'] <= 0.5
input_df = input_df[mask]
mask = input_df['rate'] >= 0.05
input_df = input_df[mask]
次に、抽出した84,635件のデータをtrain, val, testに分割し、記号や特殊文字を削除後、6つのファイルに書き出すコードを紹介します。
train_df = input_df[:80000]
val_df = input_df[80000:82000]
test_df = input_df[82000:]
train_src = train_df["src"].to_list()
train_tgt = train_df["tgt"].to_list()
val_src = val_df["src"].to_list()
val_tgt = val_df["tgt"].to_list()
test_src = test_df["src"].to_list()
test_tgt = test_df["tgt"].to_list()
# 記号や特殊文字を削除する関数
def delete_special_character(str_list):
new_list = []
characters = "■◆☆●\n"
for strength in str_list:
for x in range(len(characters)):
strength = str(strength).replace(characters[x],"")
new_list.append(strength)
return new_list
train_src = delete_special_character(train_src)
train_tgt = delete_special_character(train_tgt)
val_src = delete_special_character(val_src)
val_tgt = delete_special_character(val_tgt)
test_src = delete_special_character(test_src)
test_tgt = delete_special_character(test_tgt)
with open("train_src.txt", 'w') as f:
for d in train_src:
f.write("%s\n" % d)
with open("train_tgt.txt", 'w') as f:
for d in train_tgt:
f.write("%s\n" % d)
with open("val_src.txt", 'w') as f:
for d in val_src:
f.write("%s\n" % d)
with open("val_tgt.txt", 'w') as f:
for d in val_tgt:
f.write("%s\n" % d)
with open("test_src.txt", 'w') as f:
for d in test_src:
f.write("%s\n" % d)
with open("test_tgt.txt", 'w') as f:
for d in test_tgt:
f.write("%s\n" % d)
前処理用の環境変数の設定
# データセットの前処理の設定
%env TRAIN_SRC=train_src.txt
%env TRAIN_TGT=train_tgt.txt
%env VALID_SRC=val_src.txt
%env VALID_TGT=val_tgt.txt
%env TEST_SRC=test_src.txt
%env TEST_TGT=test_tgt.txt
# ダウンロードしたセンテンスピースモデル
%env SENTENCEPIECE_MODEL=japanese_bart_base_1.1/sp.model
# 前処理後のファイルを入れるフォルダ
%env DATASET_DIR=datasets/
# ダウンロードした辞書ファイル
%env DICT=japanese_bart_base_1.1/dict.txt
前処理実行
jaBART_preprocess.pyを使用して、データセットの前処理をします。最初に半幅文字を全幅文字に変換します。次に、Juman++をデータセットに適用し、形態素解析を行います。最後に、形態素解析後のデータセットにセンテンスピースを適用します。
!cat $TRAIN_SRC | python3 fairseq/jaBART_preprocess.py --bpe_model $SENTENCEPIECE_MODEL --bpe_dict $DICT > $DATASET_DIR/train.src-tgt.src
!cat $TRAIN_TGT | python3 fairseq/jaBART_preprocess.py --bpe_model $SENTENCEPIECE_MODEL --bpe_dict $DICT > $DATASET_DIR/train.src-tgt.tgt
!cat $VALID_SRC | python3 fairseq/jaBART_preprocess.py --bpe_model $SENTENCEPIECE_MODEL --bpe_dict $DICT > $DATASET_DIR/valid.src-tgt.src
!cat $VALID_TGT | python3 fairseq/jaBART_preprocess.py --bpe_model $SENTENCEPIECE_MODEL --bpe_dict $DICT > $DATASET_DIR/valid.src-tgt.tgt
!cat $TEST_SRC | python3 fairseq/jaBART_preprocess.py --bpe_model $SENTENCEPIECE_MODEL --bpe_dict $DICT > $DATASET_DIR/test.src-tgt.src
!cat $TEST_TGT | python3 fairseq/jaBART_preprocess.py --bpe_model $SENTENCEPIECE_MODEL --bpe_dict $DICT > $DATASET_DIR/test.src-tgt.tgt
!cp $DICT $DATASET_DIR/dict.src.txt
!cp $DICT $DATASET_DIR/dict.tgt.txt
5. ファインチューニング
fairseq-trainを使用して、新しいモデルをトレーニングします。
ファインチューニング用の環境変数の設定
# ダウンロードした事前学習済みBARTモデル
%env PRETRAINED_MODEL=japanese_bart_base_1.1/bart_model.pt
# bart_baseまたはbart_largeを設定
%env BART=bart_base
# その他保存用フォルダの設定
%env TENSORBOARD_DIR=log/
%env SAVE_MODEL_DIR=save/
%env RESULT=result.txt
ファインチューニング実行
前処理済みデータを利用してファインチューニングを実行します。以下の設定では5epochまで学習を行います。
日本語版BARTの事前学習モデルでは、データのtokenの大きさが1024までと設定されているため、1024を超えるデータを使用するとエラーが発生してしまいます。そこで、ファインチューニング時に、--skip-invalid-size-inputs-valid-test
オプションを指定することで学習できないデータをスキップできます。
今回ファインチューニングに使用するデータセットは、1500字以下のデータのみ使用していますが、tokenの大きさが1024を超えるデータが1件ありました。
!CUDA_VISIBLE_DEVICES=0 fairseq-train $DATASET_DIR --arch $BART --restore-file $PRETRAINED_MODEL \
--save-dir $SAVE_MODEL_DIR --tensorboard-logdir $TENSORBOARD_DIR \
--task translation_from_pretrained_bart --source-lang src --target-lang tgt \
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 --dataset-impl raw \
--optimizer adam --adam-eps 1e-06 --adam-betas '{0.9, 0.98}' --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 \
--warmup-updates 2500 --total-num-update 1000 --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 1024 --update-freq 5 --save-interval -1 --no-epoch-checkpoints --seed 222 --log-format simple --log-interval 2 \
--reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler --save-interval-updates 10000 \
--ddp-backend no_c10d --max-epoch 5 \
--encoder-normalize-before --decoder-normalize-before --prepend-bos \
--skip-invalid-size-inputs-valid-test
ファインチューニング後モデルの保存
モデル名を指定して、ファインチューニング済みモデルなどの要約実行に必要なファイルを保存します。以下では、”livedoor_5_eopch”という名前で保存します。
%env MODEL_NAME=livedoor_5_eopch
!mkdir models/$MODEL_NAME
!cp -rf $SAVE_MODEL_DIR models/$MODEL_NAME
!cp -rf $DICT models/$MODEL_NAME
!cp -rf $SENTENCEPIECE_MODEL models/$MODEL_NAME
!cp -rf $TENSORBOARD_DIR models/$MODEL_NAME
!cp -rf $DATASET_DIR models/$MODEL_NAME
6. 要約実行
要約用スクリプト
fairseq/fairseq_cli/interactive.py を参考に要約実行用のスクリプトを作成します。
主な変更箇所は以下になります。
get_symbols_to_strip_from_output関数を自作した。
use_cuda = False にした。
zenhanとpyknpとsentencepieceをimportし、encode_fn関数を自作した。
要約実行のコードは以下になります。
テスト用の要約前データとして、test_src.txt を用意します。
from collections import namedtuple
import fileinput
import logging
import math
import sys
import time
import os
import numpy as np
import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import encoders
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
import zenhan
from pyknp import Juman
import sentencepiece
logging.basicConfig(
format='%(name)s | %(message)s',
level=os.environ.get('LOGLEVEL', 'INFO').upper(),
stream=sys.stdout,
)
logger = logging.getLogger('infer')
Batch = namedtuple('Batch', 'ids src_tokens src_lengths constraints')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def get_symbols_to_strip_from_output(generator):
if hasattr(generator, 'symbols_to_strip_from_output'):
return generator.symbols_to_strip_from_output
else:
return {generator.eos}
def buffered_read(input, buffer_size):
buffer = []
with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
for src_str in h:
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
if len(buffer) > 0:
yield buffer
def make_batches(lines, args, task, max_positions, encode_fn):
tokens = [
task.source_dictionary.encode_line(
encode_fn(src_str), add_if_not_exist=False
).long()
for src_str in lines
]
if args.constraints:
constraints_tensor = pack_constraints(batch_constraints)
else:
constraints_tensor = None
lengths = [t.numel() for t in tokens]
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test
).next_epoch_itr(shuffle=False)
for batch in itr:
ids = batch['id']
src_tokens = batch['net_input']['src_tokens']
src_lengths = batch['net_input']['src_lengths']
constraints = batch.get("constraints", None)
yield Batch(
ids=ids,
src_tokens=src_tokens,
src_lengths=src_lengths,
constraints=constraints,
)
def main(args):
start_time = time.time()
total_translate_time = 0
utils.import_user_module(args)
if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
args.max_sentences = 1
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
'--max-sentences/--batch-size cannot be larger than --buffer-size'
# Fix seed for stochastic decoding
if args.seed is not None and not args.no_seed_provided:
np.random.seed(args.seed)
utils.set_torch_seed(args.seed)
use_cuda = False
# Setup task, e.g., translation
task = tasks.setup_task(args)
# Load ensemble
logger.info('loading model(s) from {}'.format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
args.pa。rrides=eval(args.model_overrides),
task=task,
suffix=getattr(args, "checkpoint_suffix", ""),
)
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
for model in models:
model.prepare_for_inference_(args)
if args.fp16:
model.half()
if use_cuda:
model.cuda()
# Initialize generator
generator = task.build_generator(models, args)
# Handle tokenization and BPE
tokenizer = encoders.build_tokenizer(args)
bpe = encoders.build_bpe(args)
jumanpp = Juman()
spm = sentencepiece.SentencePieceProcessor()
spm.Load(args.bpe_model)
return ' '.join([mrph.midasi for mrph in result.mrph_list()])
def bpe_encode(line, spm):
return ' '.join(spm.EncodeAsPieces(line.strip()))
def encode_fn(x):
x = x.strip()
x = zenhan.h2z(x)
x = juman_split(x, jumanpp)
x = bpe_encode(x, spm)
return x
def decode_fn(x):
x = x.translate({ord(i): None for i in ['▁', ' ']})
return x
align_dict = utils.load_align_dict(args.replace_unk)
max_positions = utils.resolve_max_positions(
task.max_positions(),
*[model.max_positions() for model in models]
)
if args.constraints:
logger.warning("NOTE: Constrained decoding currently assumes a shared subword vocabulary.")
if args.buffer_size > 1:
logger.info('Sentence buffer size: %s', args.buffer_size)
logger.info('NOTE: hypothesis and token scores are output in base 2')
logger.info('Type the input sentence and press return:')
start_id = 0
# 入力用ファイルを指定する
input_text = 'test_src.txt'
# 出力用の配列
output_texts = []
output_texts_2 = []
output_texts_3 = []
for inputs in buffered_read(input_text, args.buffer_size):
results = []
for batch in make_batches(inputs, args, task, max_positions, encode_fn):
bsz = batch.src_tokens.size(0)
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
constraints = batch.constraints
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
if constraints is not None:
constraints = constraints.cuda()
sample = {
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
},
}
translate_start_time = time.time()
translations = task.inference_step(generator, models, sample, constraints=constraints)
translate_time = time.time() - translate_start_time
total_translate_time += translate_time
list_constraints = [[] for _ in range(bsz)]
if args.constraints:
list_constraints = [unpack_constraints(c) for c in constraints]
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
constraints = list_constraints[i]
results.append((start_id + id, src_tokens_i, hypos,
{
"constraints": constraints,
"time": translate_time / len(translations)
}))
# sort output to match input order
for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
print(f'Inference time: {info["time"]:.3f} seconds')
# Process top predictions
for hypo_i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)
detok_hypo_str = decode_fn(hypo_str)
score = hypo['score'] / math.log(2) # convert to base 2
if hypo_i == 0:
output_texts.append(detok_hypo_str)
if hypo_i == 1:
output_texts_2.append(detok_hypo_str)
if hypo_i == 2:
output_texts_3.append(detok_hypo_str)
print(f'Top {hypo_i+1} prediction score: {score}')
# update running id_ counter
start_id += len(inputs)
# 要約結果を1〜3番目までそれぞれ書き出し
with open(f"{SAVE_MODEL_NAME}_test_tgt.txt", 'w') as f:
count = 0
for d in output_texts:
count+=1
f.write("%s\n" % d)
with open(f"{SAVE_MODEL_NAME}_test_tgt_2.txt", 'w') as f:
count = 0
for d in output_texts_2:
count+=1
f.write("%s\n" % d)
with open(f"{SAVE_MODEL_NAME}_test_tgt_3.txt", 'w') as f:
count = 0
for d in output_texts_3:
count+=1
f.write("%s\n" % d)
要約用スクリプトの実行
SAVE_MODEL_NAME = "livedoor_5_eopch"
MODEL_NAME = "models/" + SAVE_MODEL_NAME
def cli_main():
parser = options.get_interactive_generation_parser()
parser.add_argument('--bpe_model', default='', required=True)
parser.add_argument('--bpe_dict', default='', required=True)
bpe_model = MODEL_NAME + "/sp.model"
bpe_dict = MODEL_NAME + "/dict.txt"
datasets_dir = MODEL_NAME + "/datasets"
tuning_model = MODEL_NAME + "/save/checkpoint_best.pt"
input_args = [
datasets_dir,
"--path", tuning_model,
"--task", "translation_from_pretrained_bart",
"--max-sentences", "1",
"--bpe_model", bpe_model,
"--bpe_dict", bpe_dict,
"--nbest", "3",
"--skip-invalid-size-inputs-valid-test"
]
args = options.parse_args_and_arch(parser,input_args)
distributed_utils.call_main(args, main)
cli_main()
要約結果
以下、要約結果の一例です。
要約前 (1504文字)
最近、ネット上で「1歳児の74%がスマホを利用している」という統計結果が報じらました。そのアンケート調査によると、0歳児では24%なのが1歳で急増、2歳で85%となり、その後は90%前後で変化なく推移するようです。利用の内容は、お気に入りのキャラの動画、ゲーム、知育アプリなどが主なものです。スマホの登場で、赤ちゃんが言葉を憶える前から、PCの操作法を学び始める時代となりました。画期的な進歩です。しかし、同時に一抹の不安を抱かれることでしょう。今日、ネット依存が個人を蝕んでいるほか、企業もIT中毒により創造性を失いつつあると警鐘が鳴らされています。乳幼児の心の発達の上で、特に1歳前後というと、離乳期にあたります。また、運動能力が発達して自立歩行を始める時期であり、認知能力も高まるために外の世界に関心が開かれて、お母さんから物理的に離れ、いわば「冒険をしていく」ことが際立った特徴です。それまでは、お母さんに抱っこされて一心同体のように来たのが、この時期にはお母さんから離れだし、自分の興味の向くままに動き始めます。もっとも、歩行と言っても「よちよち歩き」ですから、お母さんと離れることには強い不安も心の底の方では感じているわけです。離れては、お母さんのいないことに気づき、お母さんの元に戻るという行動を繰り返します。これが3歳くらいになると、お母さんが目の前にいなくても安心して、ある程度一人で過ごせるようになってくるとされています。お母さん離れをして一人で平気でいられるようになるのは、子どもが想像力を働かせて、「目の前にいなくても、本当にいなくなったわけではない、お母さんはちゃんと傍にいてくれている」とわかっているからでしょう。また、逆に「お母さんが目の前にいない」からこそ、想像力が発達してくると言えます。私が心配しているのは、赤ちゃんの「おもり」のためにスマホを持たせるような場合です。1歳前後のこの時期、子どもは外の世界をいわば冒険し始めるのに、ゲームで慰めてくれるアプリは「アメ玉」のようなもので、安全なまま強い満足感を起こす刺激を与えてくれます。一方、外の世界の冒険では。万が一のことがあれば、自分で身を守らなければなりません。しかも二次元ではなく、五官で体験する生きた世界なのです。アプリの世界は刺激的で面白すぎるので、放っておくと「食べ過ぎ」になってしまうでしょう。また、欲しいときに欲しいだけ目の前に映像が現れるので、全能感を刺激されやすいといえるでしょう。これらのことが、中毒性につながりやすい要因ではないかと思います。現実では、欲しいときに欲しいものだけが現れるものではありません。せっかくお母さん以外に魅力的なものを発見して、独り立ちしようとする時期に、アプリの映像に関心が貼り付いてしまったら、子どもは想像する必要がなくなり、心の発達が阻害されてしまうでしょう。ですから、この時期は、子どもが自分なりに一人で現実の世界と接することを最優先してあげるべきです。その上で、並行してアプリの想像の世界も親と一緒に少しずつ経験させれば良いのではないでしょうか。「おやつ」の与え過ぎにならないように。ネットによる人間関係も、現実の関係があったうえに相手のことを想像して成り立ってくるものです。ITの仮想空間を使いこなすためには、まずしっかりとした現実感覚に基づく想像力が不可欠です。そのために、やはり昔ながらの子育てのやり方は相変わらず役に立つと思います。この基礎に立ってIT学習をさせれば、子どもはその弊害のリスクを克服して、新たな進歩をもたらしてくれる。そう信じたいものです。
要約後 (117文字)
スマホの登場で、子どもが言葉を憶える前からPCの操作法を学び始める時代となった。同時に一抹の不安を抱かれることでしょう。ネット依存が個人を蝕んでいるほか、企業もIT中毒により創造性を失いつつあると警鐘が鳴らされている。
7. 定量的評価
評価用スクリプト
テスト用の要約後の正解データとして、test_correct_tgt.txt を用意します。
正解データとBARTによる要約結果データを比較することで定量的な評価を行います。
以下の評価用スクリプトでは、3種類の評価を行います。
from bert_score import score
from sumeval.metrics.bleu import BLEUCalculator
from sumeval.metrics.rouge import RougeCalculator
def calc_bert_score(cands, refs):
""" BERTスコアの算出
Args:
cands ([List[str]]): [比較元の文]
refs ([List[str]]): [比較対象の文]
Returns:
[(List[float], List[float], List[float])]: [(Precision, Recall, F1スコア)]
"""
Precision, Recall, F1 = score(cands, refs, lang="ja", verbose=True)
return Precision.numpy().tolist(), Recall.numpy().tolist(), F1.numpy().tolist()
def evaluate_all(predict_file, correct_file, model_name):
bert_score_p = []
bert_score_r = []
bert_score_f = []
rouge_1_score = []
rouge_2_score = []
rouge_L_score = []
rouge_BE_score = []
blue_scores = []
with open(predict_file) as f:
predicts = f.readlines()
with open(correct_file) as f:
corrects = f.readlines()
P, R, F1 = calc_bert_score(predicts, corrects)
for p,r, f1 in zip(P, R, F1):
bert_score_p.append(p)
bert_score_r.append(r)
bert_score_f.append(f1)
for i in range(len(predicts)):
rouge = RougeCalculator(lang="ja")
rouge_1 = rouge.rouge_n(
summary=predicts[i],
references=corrects[i],
n=1)
rouge_2 = rouge.rouge_n(
summary=predicts[i],
references=corrects[i],
n=2)
rouge_l = rouge.rouge_l(
summary=predicts[i],
references=corrects[i])
rouge_be = rouge.rouge_be(
summary=predicts[i],
references=corrects[i])
rouge_1_score.append(rouge_1)
rouge_2_score.append(rouge_2)
rouge_L_score.append(rouge_l)
rouge_BE_score.append(rouge_be)
for i in range(len(predicts)):
bleu_ja = BLEUCalculator(lang="ja")
blue_score = bleu_ja.bleu(predicts[i], corrects[i])
blue_scores.append(blue_score)
# 出力用のDataFrameを作成
df = pd.DataFrame(
list(zip(bert_score_p, bert_score_r, bert_score_f, rouge_1_score, rouge_2_score, rouge_L_score, rouge_BE_score, blue_scores)),
columns = ["bert_score_p", "bert_score_r","bert_score_f","rouge_1_score","rouge_2_score","rouge_L_score","rouge_BE_score","blue_scores"]
)
display(df)
# 出力
df.to_csv(f"score/{model_name}_score.csv", mode="w", encoding="utf-8")
評価用スクリプトの実行
model_name = "livedoor_5_eopch"
correct_file = "test_correct_tgt.txt"
predict_file = model_name + "_test_tgt.txt"
evaluate_all(predict_file, correct_file, model_name)
評価結果
result_df = pd.read_csv(f"score/{model_name}_score.csv", encoding="utf-8")
result_df.describe()
bert_score_p bert_score_r bert_score_f rouge_1_score rouge_2_score rouge_L_score rouge_BE_score blue_scores count 108.000000 108.000000 108.000000 108.000000 108.000000 108.000000 108.000000 108.000000 mean 0.681451 0.702063 0.691378 0.343431 0.106225 0.239164 0.068418 6.821437 std 0.044127 0.043215 0.041866 0.096086 0.098048 0.093328 0.101410 9.179560 min 0.592922 0.583417 0.598654 0.121212 0.000000 0.118812 0.000000 0.087738 25% 0.653563 0.675007 0.664031 0.287083 0.045352 0.173608 0.000000 0.437510 50% 0.673625 0.694615 0.683875 0.329654 0.077670 0.215129 0.000000 2.117097 75% 0.706842 0.729759 0.711437 0.393233 0.151365 0.286358 0.106725 10.098101 max 0.873333 0.850494 0.861762 0.717391 0.577778 0.695652 0.387097 48.448171
参考