業界・業務から探す
導入目的・課題から探す
データ・AIについて学ぶ
News
Hakkyについて
ウェビナーコラム
◆トップ【AI・機械学習】
プロセスの全体像前処理・特徴量生成Fine Tuning手法まとめ機械学習モデルの選び方モデル評価手法プロトタイピング探索的分析(EDA)
AI

執筆者:Handbook編集部

WhisperのFine-Tuningデモ

概要

本記事ではOpenAIのWhisperを使ったFine Tuningのデモをサンプルコードと共に紹介します。本デモはGoogle Colabでの動作を想定しています。

ライブラリのインストール

最初に本デモで必要なライブラリのインストールを行います。本記事で利用するライブラリは以下のとおりです。

%%capture
! pip install git+https://github.com/openai/whisper.git
! pip install jiwer

! pip install pyopenjtalk==0.3.0
! pip install pytorch-lightning==1.7.7
! pip install -qqq evaluate==0.2.2
! pip install --upgrade --no-cache-dir gdown

初期設定

利用するモジュールのインポートをします。

import IPython.display
from pathlib import Path

import os
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as at

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from tqdm.notebook import tqdm
import pyopenjtalk
import evaluate

from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

のちに利用する変数もここで定義しておきます。

DATASET_DIR = "/content/jvs/jvs_ver1"
SAMPLE_RATE = 16000
BATCH_SIZE = 2
TRAIN_RATE = 0.8

AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120
SEED = 3407
DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
seed_everything(SEED, workers=True)

のちに利用する音声読み込み用のヘルパースクリプトを作成します。

def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        waveform = at.Resample(sr, sample_rate)(waveform)
    return waveform

JVSデータセットのダウンロード

JVSは日本語テキストと音声データからなる音声コーパスです。本デモではJVSを用いてFine Tuningの実験を行います。

JVSの概要については以下のとおりです。

  • 100人のプロフェッショナル話者(声優・俳優など)によって収録された音声データ
  • 各話者について、話者間で共通する読み上げ音声 100 発話、話者間で全く異なる読み上げ音声 30 発話、ささやき声 10 発話、裏声 10 発話が含まれる
  • 高音質(スタジオ収録)・高サンプリングレート(24 kHz)・多数の (30 時間) 音声ファイル
  • 便利なタグが付与されている (例: 性別,F0レンジ,話者類似度,継続長,音素アライメント (自動抽出)))

データはGoogle Driveからダウンロード可能であり、本デモではgdownを用いてDownloadしてくることにします。

%%capture
import gdown
gdown.download('https://drive.google.com/u/0/uc?id=19oAw8wWn3Y7z6CKChRdAyGOB9yupL_Xt', 'jvs.zip', quiet=False)
!unzip jvs.zip -d ./jvs
# エラーなく結果が出力されていればOK
!ls jvs/jvs_ver1
# 発話サンプル
import IPython.display
IPython.display.Audio(filename="/content/jvs/jvs_ver1/jvs001/nonpara30/wav24kHz16bit/BASIC5000_0025.wav", rate=SAMPLE_RATE)

データセットの前処理

以下ではデータセットの情報を取得し、Train/Valの分割を行います>

dataset_dir = Path(DATASET_DIR)
transcripts_path_list = list(dataset_dir.glob("*/*/transcripts_utf8.txt"))
print(len(transcripts_path_list))
def get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=480000, sample_rate=16000):
    audio_transcript_pair_list = []
    for transcripts_path in tqdm(transcripts_path_list):
        # audioファイルのディレクトリ確認
        audio_dir = transcripts_path.parent / "wav24kHz16bit"
        if not audio_dir.exists():
            print(f"{audio_dir}は存在しません。")
            continue

        # 翻訳テキストからAudioIdとテキストを取得
        with open(transcripts_path, "r") as f:
            text_list = f.readlines()
        for text in text_list:
            audio_id, text = text.replace("\n", "").split(":")
            #print(audio_id, text)

            audio_path = audio_dir / f"{audio_id}.wav"
            if audio_path.exists():
                # データのチェック
                audio = load_wave(audio_path, sample_rate=sample_rate)[0]
                if len(text) > text_max_length or len(audio) > audio_max_sample_length:
                    print(len(text), len(audio))
                    continue
                audio_transcript_pair_list.append((audio_id, str(audio_path), text))
    return audio_transcript_pair_list

train_num = int(len(transcripts_path_list) * TRAIN_RATE)
train_transcripts_path_list, eval_transcripts_path_list = transcripts_path_list[:train_num], transcripts_path_list[train_num:]
train_audio_transcript_pair_list = get_audio_file_list(train_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
eval_audio_transcript_pair_list = get_audio_file_list(eval_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)

print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pair_list))
print("EVAL AUDIO DATASET NUM: ", len(eval_audio_transcript_pair_list))

# OUTPUTS:
# TRAIN AUDIO DATASET NUM:  11992
# EVAL AUDIO DATASET NUM:  2990

torchデータローダの作成

以下のコードでは、torchデータセットとデータローダの作成を行います。

# カナへの変換
def text_kana_convert(text):
    text = pyopenjtalk.g2p(text, kana=True)
    return text
print(text_kana_convert("こんにちは、私の名前は、田中一郎です。"))
woptions = whisper.DecodingOptions(language="ja", without_timestamps=True)
wmodel = whisper.load_model("base")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="ja", task=woptions.task)
# JVSデータセットのtorchデータセットクラス
class JvsSpeechDataset(torch.utils.data.Dataset):
    def __init__(self, audio_info_list, tokenizer, sample_rate) -> None:
        super().__init__()

        self.audio_info_list = audio_info_list
        self.sample_rate = sample_rate
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.audio_info_list)

    def __getitem__(self, id):
        audio_id, audio_path, text = self.audio_info_list[id]

        # Preprocessing
        audio = load_wave(audio_path, sample_rate=self.sample_rate)
        audio = whisper.pad_or_trim(audio.flatten())
        mel = whisper.log_mel_spectrogram(audio) # Log Mel Spec

        text = text_kana_convert(text)d
        text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
        labels = text[1:] + [self.tokenizer.eot]

        return {
            "input_ids": mel,
            "labels": labels,
            "dec_input_ids": text
        }
# collate_fnで利用するクラスの作成

class WhisperDataCollatorWhithPadding:
    def __call__(sefl, features):
        input_ids, labels, dec_input_ids = [], [], []
        for f in features:
            input_ids.append(f["input_ids"])
            labels.append(f["labels"])
            dec_input_ids.append(f["dec_input_ids"])

        input_ids = torch.concat([input_id[None, :] for input_id in input_ids])

        label_lengths = [len(lab) for lab in labels]
        dec_input_ids_length = [len(e) for e in dec_input_ids]
        max_label_len = max(label_lengths+dec_input_ids_length)

        labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
        dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id

        batch = {
            "labels": labels,
            "dec_input_ids": dec_input_ids
        }

        batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
        batch["input_ids"] = input_ids
        return batch

以下でデータローダの動作確認を行います。

dataset = JvsSpeechDataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

for b in loader:
    print(b["labels"].shape)
    print(b["input_ids"].shape)
    print(b["dec_input_ids"].shape)

    for token, dec in zip(b["labels"], b["dec_input_ids"]):
        token[token == -100] = wtokenizer.eot
        text = wtokenizer.decode(token, skip_special_tokens=False)
        print(text)

        dec[dec == -100] = wtokenizer.eot
        text = wtokenizer.decode(dec, skip_special_tokens=False)
        print(text)
    break

# OUTPUTS:
# torch.Size([2, 59])
# torch.Size([2, 80, 3000])
# torch.Size([2, 59])
# <|ja|><|transcribe|><|notimestamps|>マタ、トージノヨーニ、ゴダイミョーオートヨバレル、シュヨーナミョーオーノチューオーニハイサレルコトモオーイ。<|endoftext|>
# <|startoftranscript|><|ja|><|transcribe|><|notimestamps|>マタ、トージノヨーニ、ゴダイミョーオートヨバレル、シュヨーナミョーオーノチューオーニハイサレルコトモオーイ。
# <|ja|><|transcribe|><|notimestamps|>ニューイングランドフーワ、ギューニューヲベーストシタ、シロイクリームスープデアリ、ボストンクラムチャウダートモヨバレル。<|endoftext|>
# <|startoftranscript|><|ja|><|transcribe|><|notimestamps|>ニューイングランドフーワ、ギューニューヲベーストシタ、シロイクリームスープデアリ、ボストンクラムチャウダートモヨバレル。
with torch.no_grad():
    audio_features = wmodel.encoder(b["input_ids"].cuda())
    input_ids = b["input_ids"]
    labels = b["labels"].long()
    dec_input_ids = b["dec_input_ids"].long()


    audio_features = wmodel.encoder(input_ids.cuda())
    print(dec_input_ids)
    print(input_ids.shape, dec_input_ids.shape, audio_features.shape)
    print(audio_features.shape)
    print()
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)

print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)

"""
OUTPUTS:
 tensor([[50258, 50266, 50359, 50363, 13258, 12144,  1231,  7588, 44143, 34501,
            1047,   101, 15266,   233,  1231, 39780, 28651,  8040, 23196, 28579,
            3384, 18743, 38551,  1047,   101, 18593, 16680,  9405,  1231, 11054,
            26167,  1047,   101, 15266,   232, 23196, 28579,  3384, 18743, 15266,
            236, 17794, 26167,  3384, 18743, 15266,   233, 15927,  8040, 23607,
            16680,  9405, 18066,  7588, 29183, 18743,  3384,  8040,  1543],
        [50258, 50266, 50359, 50363, 34737, 26167,  3384,  8040, 43017, 11353,
            46271, 17320, 15266,   107,  1231,   824,   106, 26167, 15266,   233,
            26167, 15266,   110, 45290,  3384, 40498, 11054, 12144,  1231, 11054,
            17164,  8040, 10825, 12376, 44265,  9550, 15266,   245, 31327, 12817,
            12376,  1231, 37626, 40498,  4824, 10825, 11353, 32026, 31771, 20745,
            28651, 38551, 29183,  1047,   101, 18593, 16680,  9405,  1543]])
torch.Size([2, 80, 3000]) torch.Size([2, 59]) torch.Size([2, 1500, 512])
torch.Size([2, 1500, 512])
"""
tokens = torch.argmax(out, dim=2)
for token in tokens:
    token[token == -100] = wtokenizer.eot
    text = wtokenizer.decode(token, skip_special_tokens=True)
    print(text)

torch-lightningのTrainerの作成

以下のコードではtorch-lightningのTrainerを作成します。

# ハイパーパラメータの設定
class Config:
    learning_rate = 0.0005
    weight_decay = 0.01
    adam_epsilon = 1e-8
    warmup_steps = 2
    batch_size = 16
    num_worker = 2
    num_train_epochs = 10
    gradient_accumulation_steps = 1
    sample_rate = SAMPLE_RATE

# カスタムTrainerクラス
class WhisperModelModule(LightningModule):
    def __init__(self, cfg:Config, model_name="base", lang="ja", train_dataset=[], eval_dataset=[]) -> None:
        super().__init__()
        self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)
        self.model = whisper.load_model(model_name)
        self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="ja", task=self.options.task)

        # only decoder training
        for p in self.model.encoder.parameters():
            p.requires_grad = False

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics_wer = evaluate.load("wer")
        self.metrics_cer = evaluate.load("cer")

        self.cfg = cfg
        self.__train_dataset = train_dataset
        self.__eval_dataset = eval_dataset

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_id):
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()

        with torch.no_grad():
            audio_features = self.model.encoder(input_ids)

        out = self.model.decoder(dec_input_ids, audio_features)
        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
        self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_id):
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()


        audio_features = self.model.encoder(input_ids)
        out = self.model.decoder(dec_input_ids, audio_features)

        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))

        out[out == -100] = self.tokenizer.eot
        labels[labels == -100] = self.tokenizer.eot

        o_list, l_list = [], []
        for o, l in zip(out, labels):
            o = torch.argmax(o, dim=1)
            o_list.append(self.tokenizer.decode(o, skip_special_tokens=True))
            l_list.append(self.tokenizer.decode(l, skip_special_tokens=True))
        cer = self.metrics_cer.compute(references=l_list, predictions=o_list)
        wer = self.metrics_wer.compute(references=l_list, predictions=o_list)

        self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True)
        self.log("val/cer", cer, on_step=True, prog_bar=True, logger=True)
        self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True)

        return {
            "cer": cer,
            "wer": wer,
            "loss": loss
        }

    def configure_optimizers(self):
        """オプティマイザーとスケジューラーを作成する"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                            if not any(nd in n for nd in no_decay)],
                "weight_decay": self.cfg.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                            if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.cfg.learning_rate,
                          eps=self.cfg.adam_epsilon)
        self.optimizer = optimizer

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.cfg.warmup_steps,
            num_training_steps=self.t_total
        )
        self.scheduler = scheduler

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def setup(self, stage=None):
        """初期設定(データセットの読み込み)"""

        if stage == 'fit' or stage is None:
            self.t_total = (
                (len(self.__train_dataset) // (self.cfg.batch_size))
                // self.cfg.gradient_accumulation_steps
                * float(self.cfg.num_train_epochs)
            )

    def train_dataloader(self):
        """訓練データローダーを作成する"""
        dataset = JvsSpeechDataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset,
                          batch_size=self.cfg.batch_size,
                          drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,
                          collate_fn=WhisperDataCollatorWhithPadding()
                          )

    def val_dataloader(self):
        """バリデーションデータローダーを作成する"""
        dataset = JvsSpeechDataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset,
                          batch_size=self.cfg.batch_size,
                          num_workers=self.cfg.num_worker,
                          collate_fn=WhisperDataCollatorWhithPadding()
                          )


TensorBoardの設定

学習ログを表示するためのTensorBoardの設定を行います。

%load_ext tensorboard
%tensorboard --logdir /content/logs

モデル学習

以上でWhisper学習に必要な準備が整ったため、ここからは実際にモデルの学習を行っていきます。

log_output_dir = "/content/logs"
check_output_dir = "/content/artifacts"

train_name = "whisper"
train_id = "00001"

model_name = "base"
lang = "ja"
cfg = Config()

Path(log_output_dir).mkdir(exist_ok=True)
Path(check_output_dir).mkdir(exist_ok=True)

tflogger = TensorBoardLogger(
    save_dir=log_output_dir,
    name=train_name,
    version=train_id
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{check_output_dir}/checkpoint",
    filename="checkpoint-{epoch:04d}",
    save_top_k=-1 # all model save
)

callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")]
model = WhisperModelModule(cfg, model_name, lang, train_audio_transcript_pair_list, eval_audio_transcript_pair_list)

trainer = Trainer(
    precision=16,
    accelerator=DEVICE,
    max_epochs=cfg.num_train_epochs,
    accumulate_grad_batches=cfg.gradient_accumulation_steps,
    logger=tflogger,
    callbacks=callback_list
)

trainer.fit(model)

評価

以下のコードでは学習したモデルの評価を行います。今回はCER(Character Error Rate)で評価を行います。

筆者環境ではCERが0.014とCERがかなり低い結果となりました。今回の場合かなに変換していたりとタスクが簡単であったため今回の結果を一概に評価することはできませんが、評価結果や学習曲線からFine Tuning自体はうまくいっているように感じます。

checkpoint_path = "/content/artifacts/checkpoint/checkpoint-epoch=0007.ckpt"
state_dict = torch.load(checkpoint_path)
print(state_dict.keys())
state_dict = state_dict['state_dict']
whisper_model = WhisperModelModule(cfg)
whisper_model.load_state_dict(state_dict)
woptions = whisper.DecodingOptions(language="ja", without_timestamps=True)
dataset = JvsSpeechDataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE) # Eval dataset
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

refs = [] # GT
res = [] # Predictions

for b in tqdm(loader):
    input_ids = b["input_ids"].half().cuda()
    labels = b["labels"].long().cuda()
    with torch.no_grad():
        results = whisper_model.model.decode(input_ids, woptions)
        for r in results:
            res.append(r.text)

        for l in labels:
            l[l == -100] = wtokenizer.eot
            ref = wtokenizer.decode(l, skip_special_tokens=True)
            refs.append(ref)
cer_metrics = evaluate.load("cer")
cer_metrics.compute(references=refs, predictions=res)
# OUTPUTS:
# 0.014300600846950314
# 推論結果の表示
for k, v in zip(refs, res):
    print("-"*10)
    print(k)
    print(v)

まとめ

本記事ではWhisperのFine Tuningについて紹介しました。Hakkyでは今後もWhisper活用のため研究開発を行っていきます。

参考

2025年06月15日に最終更新
読み込み中...