業界・業務から探す
導入目的・課題から探す
データ・AIについて学ぶ
News
Hakkyについて
ウェビナーコラム
◆トップ【AI・機械学習】
プロセスの全体像前処理・特徴量生成Fine Tuning手法まとめ機械学習モデルの選び方モデル評価手法プロトタイピング探索的分析(EDA)
AI

執筆者:Handbook編集部

WhisperのFine-Tuningにあたるデータセットと設定の比較

概要

本記事ではWhisperを様々なデータセットと設定を用いてFine-Tuningして、その結果を比較します。

Whisperのモデルの概要

WhisperはEncoder-decoder TransformerSeq2Seqとも呼ばれ)のモデルを使用します。すなわち、モデルは音声のシークエンスを単語のシークエンスに変換します。

データセット

これから増やす予定ですが、最初以下の二つのデータセットでFine-Tuningします。

実装

環境設定

本記事では、分かりやすいコードになりますように、torch以外のライブラリーをできるだけ使用しないようにします。

Colabで実行する場合は以下で必要なPythonパッケージをインストールすることができます。

!pip install openai-whisper
!pip install evaluate
!pip install jiwer

Fine-Tuningのコード

以下のコードでモデルをFine-Tuningします。時間かかりすぎないように、Tinyモデルを使用します。また、簡単にコピーできるように、一ヵ所でコードをまとめました。

import os
import logging
import glob
import pandas as pd
import whisper
import evaluate
import torch
from torch import Tensor
from tqdm import tqdm
import random
from tqdm.contrib.logging import logging_redirect_tqdm
from sklearn.model_selection import train_test_split

LOG = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

debug:bool = False
device:str = "cuda" if torch.cuda.is_available() else "cpu"
epochs:int = 4
batch_size:int = 32
max_character_count:int = 25
learning_rate:float = 1e-5
decoder_only:bool = False
teacher_forcing_ratio:float = 1.0
n_vocab:int=51865
output_size:int = 384

#データセットの読み込み(JVS)
dataset_folder:str = "jvs"
transcripts_path_list:list[str] = list(glob.glob("jvs/jvs_ver1/*/*/transcripts_utf8.txt"))
print("JVSのファイル数", len(transcripts_path_list))

def get_audio_file_list(transcripts_path_list:list[str]) -> tuple[list[str], list[str]]:
    audio:list[str] = []
    sentence:list[str] = []
    for transcripts_path in tqdm(transcripts_path_list):
        current_folder:str = os.path.dirname(transcripts_path)
        audio_dir:str = os.path.join(current_folder, "wav24kHz16bit")

        with open(transcripts_path, "r", encoding="utf8") as f:
            text_list = f.readlines()
        for text in text_list:
            audio_id, text = text.replace("\n", "").split(":")
            if len(text) > max_character_count:
                continue
            audio_path:str = os.path.join(audio_dir, f"{audio_id}.wav")
            if not os.path.exists(audio_path):
                continue
            audio.append(audio_path)
            sentence.append(text)
    return audio, sentence
audio, sentence = get_audio_file_list(transcripts_path_list)

#データセットの読み込み(Common voice)
dataset_folder:str = "cv-corpus-12.0-2022-12-07/ja/"
clips_folder:str = os.path.join(dataset_folder, 'clips')
print("Common Voiceの音声ファイル数", len(glob.glob(clips_folder + "/*.mp3")))

clips_folder:str = os.path.join(dataset_folder, 'clips')
common_voice_train = pd.read_csv(os.path.join(dataset_folder, 'train.tsv'), sep="\t")
common_voice_train = common_voice_train[common_voice_train['sentence'].str.len() <= max_character_count]
common_voice_train['path'] = common_voice_train['path'].apply(lambda x: os.path.join(clips_folder, x))
common_voice_test = pd.read_csv(os.path.join(dataset_folder, 'test.tsv'), sep="\t")
common_voice_test = common_voice_test[common_voice_test['sentence'].str.len() <= max_character_count]
common_voice_test['path'] = common_voice_test['path'].apply(lambda x: os.path.join(clips_folder, x))

audio:list[str] = audio + common_voice_train['path'].to_list() + common_voice_test['path'].to_list()
sentence:list[str] = sentence + common_voice_train['sentence'].to_list() + common_voice_test['sentence'].to_list()

X_train, X_test, y_train, y_test = train_test_split(audio, sentence, random_state=6)
print("TRAIN AUDIO DATASET NUM: ", len(X_train))
print("EVAL AUDIO DATASET NUM: ", len(X_test))

#Tokenizerを準備
woptions = whisper.DecodingOptions(language="ja", without_timestamps=True)
tokenizer = whisper.tokenizer.get_tokenizer(True, language="ja", task=woptions.task)

#モデルを読み込む
model = whisper.load_model("tiny", device=device)

#結果を評価
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

#Fine-Tuningの精度を計算
def test_dataset(path:list[str], sentence:list[str], model) -> tuple[float, float, int]:
    transcribe_result = []
    transcribe_references = []
    transcribe_result_id = []
    transcribe_references_id = []
    for i, (path, sentence) in tqdm(enumerate(zip(path, sentence))):
        result = model.transcribe(path, language='ja')
        if len(result['segments']) > 1 or len(result['segments']) < 1:
            result_tokens_string = ""
            result['text'] = ""
        else:
            result_tokens_string = " ".join(str(token) for token in result['segments'][0]['tokens'][1:-1])
        sentence_tokens = tokenizer.encode(sentence)
        sentence_tokens_string = " ".join(str(token) for token in sentence_tokens)

        transcribe_result_id.append(result_tokens_string)
        transcribe_references_id.append(sentence_tokens_string)
        transcribe_result.append(result['text'])
        transcribe_references.append(sentence)
    if len(transcribe_result) <= 0:
        return 0.0, 0.0, 0
    wer = wer_metric.compute(predictions=transcribe_result_id, references=transcribe_references_id)
    cer = cer_metric.compute(predictions=transcribe_result, references=transcribe_references)
    return wer, cer, len(transcribe_result)

def check_result(path:list[str], sentence:list[str], model) -> tuple[float, float]:
    wer, cer, count = test_dataset(path, sentence, model)
    print("現在のWER ({0:.2f}件)".format(count), wer)
    print("現在のCER ({0:.2f}件)".format(count), cer)
    return wer, cer

print("学習セットを評価")
check_result(X_train, y_train, model)
print("テストセットを評価")
check_result(X_test, y_test, model)

#モデルをFine-Tuning
start_tags:Tensor = torch.Tensor([*tokenizer.sot_sequence_including_notimestamps])
start_tags_len:int = len(start_tags)
end_tags:Tensor = torch.Tensor([tokenizer.eot])
end_tags_len:int = len(end_tags)
output_start_tags:Tensor = start_tags[1:]
output_start_tags_len:int = len(output_start_tags)
loss_fn = torch.nn.CrossEntropyLoss()
if not decoder_only:
    encoder_optimizer = torch.optim.Adam(model.encoder.parameters(), lr=learning_rate)
decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=learning_rate)

def step_model(loss:float, current_count:int):
    loss = loss/current_count
    if debug:
        LOG.info("loss: {0:.2f}".format(loss))
    loss.backward()
    if not decoder_only:
        encoder_optimizer.step()
    decoder_optimizer.step()

def epoch():
    if not decoder_only:
        encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    current_count:int = 0
    loss:float = 0
    for audio_path, sentence in tqdm(zip(X_train, y_train)):
        with torch.no_grad():
            audio:Tensor = whisper.load_audio(audio_path)
            audio:Tensor = whisper.pad_or_trim(audio)
            audio:Tensor = whisper.log_mel_spectrogram(audio)
            encoded_sentence:Tensor = torch.Tensor(tokenizer.encode(sentence))
            correct_output:Tensor = torch.cat((output_start_tags, encoded_sentence, end_tags), 0).to(torch.long).to(device)
            if len(correct_output) > output_size:
                print("入力文は長すぎる", sentence)
                continue
        if decoder_only:
            with torch.no_grad():
                audio_features = model.encoder(torch.unsqueeze(audio.to(device), 0))
        else:
            audio_features:Tensor = model.encoder(torch.unsqueeze(audio.to(device), 0))

        force_teaching:bool = teacher_forcing_ratio > random.random()
        previous_in:Tensor = start_tags.to(device)

        for i in range(0, min(len(encoded_sentence)-start_tags_len-1, output_size-output_start_tags_len-1)):
            encoded_sentence_part = torch.unsqueeze(previous_in, 0).detach()
            encoded_sentence_part = encoded_sentence_part.long().to(device).detach()
            predicted_tokens = model.decoder(encoded_sentence_part, audio_features)
            single_predicted_tokens = predicted_tokens.reshape((-1, n_vocab))
            loss += loss_fn(single_predicted_tokens, correct_output[:len(single_predicted_tokens)].to(device))

            _, predicted_token = predicted_tokens[:i + output_start_tags_len].topk(1)
            predicted_token = predicted_token.squeeze()
            if force_teaching:
                previous_in = torch.cat((previous_in, correct_output[len(single_predicted_tokens)-1:len(single_predicted_tokens)]))
            else:
                previous_in = torch.cat((previous_in, predicted_token[-1:]))
            current_count += 1
            if current_count >= batch_size:
                step_model(loss, current_count)
                if not decoder_only:
                    encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                current_count = 0
                loss = 0
                audio_features = model.encoder(torch.unsqueeze(audio.to(device), 0))#新しい重みで特徴量を生成

    if current_count > 0:
        step_model(loss, current_count)

for e in range(0, epochs):
    with logging_redirect_tqdm():
        print("####Epoch {}".format(e))
        epoch()
        print("モデルを評価")
        print("学習セットを評価")
        check_result(X_train, y_train, model)
        print("テストセットを評価")
        check_result(X_test, y_test, model)

teacher_forcing_ratio変数

teacher_forcing_ratio変数で、入力に正しい単語か予測単語を加えて次入力を作成する確率を設定できます。正しい単語(forcing)のほうは早く学習できますが、不安定になる可能性があります。

結果

Common Voice

Common Voice(100件)
  • モデル:Tiny
  • Common Voiceから100件のサンプルを学習データに
  • Common Voiceから33件のサンプルをテストデータに
  • teacher_forcing_ratio=1.0
  • Decoderのみを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.58|0.53|0.58|0.54| Epoch 1|0.49|0.41|0.60|0.53| Epoch 2|0.40|0.34|0.82|0.81| Epoch 3|0.36|0.31|0.60|0.57| Epoch 4|0.37|0.37|0.57|0.51|

学習データとテストデータの精度はどっちも改良されました。

Common Voice(100件、Encoder)
  • モデル:Tiny
  • Common Voiceから100件のサンプルを学習データに
  • Common Voiceから33件のサンプルをテストデータに
  • teacher_forcing_ratio=1.0
  • EncoderとDecoderを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.58|0.53|0.58|0.54| Epoch 1|0.38|0.32|0.53|0.46| Epoch 2|0.26|0.21|0.52|0.44| Epoch 3|0.20|0.17|0.54|0.46| Epoch 4|0.16|0.14|0.56|0.49|

学習データの精度は大幅に改良され、テストデータの精度もよくなりました。

Common Voice

  • モデル:Tiny
  • 最大文字数:25
  • teacher_forcing_ratio=1.0
  • Decoderのみを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.45|0.37|0.46|0.38| Epoch 1|0.35|0.26|0.46|0.34|

学習データの精度は改良され、テストデータの精度は少しだけ改良されました。

Common Voice(Base)

  • モデル:Base
  • 最大文字数:25
  • teacher_forcing_ratio=1.0
  • Decoderのみを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.45|0.38|0.44|0.38| Epoch 1|0.33|0.26|0.49|0.41| Epoch 2|0.27|0.21|0.51|0.43|

Baseモデル(Tinyより大きいモデル)を使用すると、学習データの精度は改良されましたが、テストデータの精度は少し落ちました。

Common Voice(Base, Encoder)

  • モデル:Base
  • 最大文字数:25
  • teacher_forcing_ratio=1.0
  • EncoderとDecoderを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.45|0.38|0.44|0.38| Epoch 1|0.43|0.36|0.49|0.40|

Baseモデルを使用し、Encoderも学習しても、学習データの精度は改良されましたが、テストデータの精度は少し落ちました。

JSV

JSV(100件)
  • モデル:Tiny
  • JSVから100件のサンプルを学習データに
  • JSVから33件のサンプルをテストデータに
  • teacher_forcing_ratio=1.0
  • Decoderのみを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.58|0.53|0.53|0.53| Epoch 1|0.46|0.40|0.58|0.49| Epoch 2|0.38|0.32|0.74|0.70| Epoch 3|0.30|0.27|0.58|0.52| Epoch 4|0.24|0.22|0.64|0.62|

学習データの精度は大分改良され、テストデータの精度は落ちました。過学習が原因だと思われています。

JSV(100件、Encoder)
  • モデル:Tiny
  • JSVから100件のサンプルを学習データに
  • JSVから33件をテストデータに
  • teacher_forcing_ratio=1.0
  • EncoderとDecoderを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.58|0.53|0.53|0.53| Epoch 1|0.29|0.25|0.41|0.34| Epoch 2|0.19|0.16|0.44|0.38| Epoch 3|0.14|0.11|0.52|0.47| Epoch 4|0.11|0.09|0.52|0.47|

学習データの精度は大幅に改良され、テストデータの精度は少しよくなりました。

JSV

  • モデル:Tiny
  • 最大文字数:25
  • teacher_forcing_ratio=1.0
  • EncoderとDecoderを学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.45|0.37|0.46|0.37| Epoch 1|0.35|0.26|0.45|0.32| Epoch 2|0.34|0.28|0.48|0.39|

学習データの精度はよくなりました。テストデータの精度はよくなったりしました。

JSV + Common Voice

  • モデル:Tiny
  • 最大文字数:25
  • teacher_forcing_ratio=1.0
  • Decoderのみ学習

||Train WER |Train CER | Test WER | Test CER --- | --- | ---|---|---| 学習前|0.54|0.47|0.53|0.47| Epoch 1|0.72|0.62|0.74|0.65|

学習データとテストデータの精度は落ちました。JSVとCommon Voiceのデータ(内容と書き方)が異なって、学習を行って、精度が落ちたと思われています。より膨大なデータセットの場合は様々な内容と書き方がカバーされ、問題にならないと考えれています。

終わりに

本記事ではWhisperを複数のデータセットと設定でFine-Tuningしました。

Encoderを学習したほうが学習データの精度がより上がりますが、テストデートに大きい影響はしません。また、teacher_forcing_ratioを変えてみましたが、基本的には精度に悪影響していました。

その上、複数のデータセットを合わせて、精度がよくならなかったですが、データセットの差異が原因だと思われています。

つまり、適切な設定を見つけれるのは困難です。これからも違うデータセットを試して、本記事を更新します。

2025年06月15日に最終更新
読み込み中...