業界・業務から探す
導入目的・課題から探す
データ・AIについて学ぶ
News
Hakkyについて
ウェビナーコラム
◆トップ【AI・機械学習】
プロセスの全体像前処理・特徴量生成Fine Tuning手法まとめ機械学習モデルの選び方モデル評価手法プロトタイピング探索的分析(EDA)
ドミナントカラー検出セグメンテーション技術の基礎と実装局所特徴量抽出Grad-CAMまとめ画像の二値化とその手法モルフォロジー演算とその手法【Vision AI】Painterの紹介pix2structの紹介
AI

執筆者:Handbook編集部

HRFormerのチュートリアル

はじめに

本記事では HRFormer の学習を行います。 今回は HRNet 同様、coco データセットによる学習済みモデルをcoco_tinyというデータセットでfinetuningを行い、実際に推論までやってみます。

動作環境

実行環境は Google Colaboratry を使用します。

必要なライブラリのインストール

下記コマンドで必要なmmpose等のパッケージをインストールします。

%pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
%pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.10.0/index.html
%pip install mmdet
%rm -rf mmpose
!git clone https://github.com/open-mmlab/mmpose.git
%cd mmpose
%pip install -r requirements.txt
%pip install -e .

import torch, torchvision
import cv2
import json
import pprint
import json
import os.path as osp
from collections import OrderedDict
import tempfile
import numpy as np

import mmpose
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
from mmpose.apis import (inference_top_down_pose_model, init_pose_model, vis_pose_result, process_mmdet_results)
from mmdet.apis import inference_detector, init_detector
from mmpose.core.evaluation.top_down_eval import (keypoint_nme, keypoint_pck_accuracy)
from mmpose.datasets.builder import DATASETS
from mmpose.datasets.datasets.base import Kpt2dSviewRgbImgTopDownDataset
from mmpose.datasets import build_dataset
from mmpose.models import build_posenet
from mmpose.apis import train_model
from mmcv import Config
from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                         vis_pose_result, process_mmdet_results)
from mmdet.apis import inference_detector, init_detector
from google.colab.patches import cv2_imshow

データセットの準備

今回はopenmmlabmmpose内にあるcoco_tinyというデータセットを学習と推論で使用します。

%mkdir data
%cd data
!wget https://download.openmmlab.com/mmpose/datasets/coco_tiny.tar
!tar -xf coco_tiny.tar
%cd ..

Dataset class の定義

学習・推論でデータをモデルに渡すための Dataset オブジェクトを定義します。 mmpose.datasets.datasets.base.Kpt2dSviewRgbImgTopDownDatasetを継承させて作成します。 説明するのは大変煩雑になるため、コピペで構いません。

@DATASETS.register_module()
class TopDownCOCOTinyDataset(Kpt2dSviewRgbImgTopDownDataset):

    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 pipeline,
                 dataset_info=None,
                 test_mode=False):
      super().__init__(
          ann_file,
          img_prefix,
          data_cfg,
          pipeline,
          dataset_info,
          coco_style=False,
          test_mode=test_mode)

      # flip_pairs, upper_body_ids and lower_body_ids will be used
      # in some data augmentations like random flip
      self.ann_info['flip_pairs'] = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
                                      [11, 12], [13, 14], [15, 16]]
      self.ann_info['upper_body_ids'] = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
      self.ann_info['lower_body_ids'] = (11, 12, 13, 14, 15, 16)

      self.ann_info['joint_weights'] = None
      self.ann_info['use_different_joint_weights'] = False

      self.dataset_name = 'coco_tiny'
      self.db = self._get_db()

    def _get_db(self):
      with open(self.ann_file) as f:
        anns = json.load(f)

      db = []
      for idx, ann in enumerate(anns):
        # get image path
        image_file = osp.join(self.img_prefix, ann['image_file'])
        # get bbox
        bbox = ann['bbox']
        # get keypoints
        keypoints = np.array(
            ann['keypoints'], dtype=np.float32).reshape(-1, 3)
        num_joints = keypoints.shape[0]
        joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
        joints_3d[:, :2] = keypoints[:, :2]
        joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
        joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])

        sample = {
            'image_file': image_file,
            'bbox': bbox,
            'rotation': 0,
            'joints_3d': joints_3d,
            'joints_3d_visible': joints_3d_visible,
            'bbox_score': 1,
            'bbox_id': idx,
        }
        db.append(sample)

      return db

    def evaluate(self, results, res_folder=None, metric='PCK', **kwargs):
      """Evaluate keypoint detection results. The pose prediction results will
      be saved in `${res_folder}/result_keypoints.json`.

      Note:
      batch_size: N
      num_keypoints: K
      heatmap height: H
      heatmap width: W

      Args:
      results (list(preds, boxes, image_path, output_heatmap))
          :preds (np.ndarray[N,K,3]): The first two dimensions are
              coordinates, score is the third dimension of the array.
          :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
              , scale[1],area, score]
          :image_paths (list[str]): For example, ['Test/source/0.jpg']
          :output_heatmap (np.ndarray[N, K, H, W]): model outputs.

      res_folder (str, optional): The folder to save the testing
              results. If not specified, a temp folder will be created.
              Default: None.
      metric (str | list[str]): Metric to be performed.
          Options: 'PCK', 'NME'.

      Returns:
          dict: Evaluation results for evaluation metric.
      """
      metrics = metric if isinstance(metric, list) else [metric]
      allowed_metrics = ['PCK', 'NME']
      for metric in metrics:
        if metric not in allowed_metrics:
          raise KeyError(f'metric {metric} is not supported')

      if res_folder is not None:
        tmp_folder = None
        res_file = osp.join(res_folder, 'result_keypoints.json')
      else:
        tmp_folder = tempfile.TemporaryDirectory()
        res_file = osp.join(tmp_folder.name, 'result_keypoints.json')

      kpts = []
      for result in results:
        preds = result['preds']
        boxes = result['boxes']
        image_paths = result['image_paths']
        bbox_ids = result['bbox_ids']

        batch_size = len(image_paths)
        for i in range(batch_size):
          kpts.append({
              'keypoints': preds[i].tolist(),
              'center': boxes[i][0:2].tolist(),
              'scale': boxes[i][2:4].tolist(),
              'area': float(boxes[i][4]),
              'score': float(boxes[i][5]),
              'bbox_id': bbox_ids[i]
          })
      kpts = self._sort_and_unique_bboxes(kpts)

      self._write_keypoint_results(kpts, res_file)
      info_str = self._report_metric(res_file, metrics)
      name_value = OrderedDict(info_str)

      if tmp_folder is not None:
        tmp_folder.cleanup()

      return name_value

    def _report_metric(self, res_file, metrics, pck_thr=0.3):
      """Keypoint evaluation.

      Args:
      res_file (str): Json file stored prediction results.
      metrics (str | list[str]): Metric to be performed.
          Options: 'PCK', 'NME'.
      pck_thr (float): PCK threshold, default: 0.3.

      Returns:
      dict: Evaluation results for evaluation metric.
      """
      info_str = []

      with open(res_file, 'r') as fin:
        preds = json.load(fin)
      assert len(preds) == len(self.db)

      outputs = []
      gts = []
      masks = []

      for pred, item in zip(preds, self.db):
        outputs.append(np.array(pred['keypoints'])[:, :-1])
        gts.append(np.array(item['joints_3d'])[:, :-1])
        masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)

      outputs = np.array(outputs)
      gts = np.array(gts)
      masks = np.array(masks)

      normalize_factor = self._get_normalize_factor(gts)

      if 'PCK' in metrics:
        _, pck, _ = keypoint_pck_accuracy(outputs, gts, masks, pck_thr,
                                          normalize_factor)
        info_str.append(('PCK', pck))

      if 'NME' in metrics:
        info_str.append(
            ('NME', keypoint_nme(outputs, gts, masks, normalize_factor)))

      return info_str

    @staticmethod
    def _write_keypoint_results(keypoints, res_file):
      """Write results into a json file."""

      with open(res_file, 'w') as f:
        json.dump(keypoints, f, sort_keys=True, indent=4)

    @staticmethod
    def _sort_and_unique_bboxes(kpts, key='bbox_id'):
      """sort kpts and remove the repeated ones."""
      kpts = sorted(kpts, key=lambda x: x[key])
      num = len(kpts)
      for i in range(num - 1, 0, -1):
        if kpts[i][key] == kpts[i - 1][key]:
          del kpts[i]

      return kpts

    @staticmethod
    def _get_normalize_factor(gts):
      """Get inter-ocular distance as the normalize factor, measured as the
      Euclidean distance between the outer corners of the eyes.

      Args:
          gts (np.ndarray[N, K, 2]): Groundtruth keypoint location.

      Return:
          np.ndarray[N, 2]: normalized factor
      """

      interocular = np.linalg.norm(
          gts[:, 0, :] - gts[:, 1, :], axis=1, keepdims=True)
      return np.tile(interocular, [1, 2])

モデルの config を定義

モデルの学習に必要な各パラメーターを定義します。 今回は coco データセットでの学習済モデルをfinetuningする形で学習を行うため、 そのモデルのconfigをベースにしてdownloadしたcoco_tinyデータセットに合わせて修正を加えていきます。 また、HRFormerではBatchNormalization層にSyncBNというものを使用していますが、 SingleGPU での実行の場合、現状エラーとなってしまうためSyncBNの設定を削除しています。

cfg = Config.fromfile(
    './configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrformer_small_coco_384x288.py'
)

cfg.data_root = 'data/coco_tiny'
cfg.work_dir = 'work_dirs/hrformer_small_coco_384x288'
cfg.gpu_ids = range(1)
cfg.seed = 0
cfg.norm_cfg = None
del cfg.model.backbone.norm_cfg

cfg.log_config.interval = 1

cfg.evaluation.interval = 10
cfg.evaluation.metric = 'PCK'
cfg.evaluation.save_best = 'PCK'

lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=10,
    warmup_ratio=0.001,
    step=[17, 35])
cfg.total_epochs = 120
cfg.data.samples_per_gpu = 16
cfg.data.val_dataloader = dict(samples_per_gpu=16)
cfg.data.test_dataloader = dict(samples_per_gpu=16)

cfg.data.train.type = 'TopDownCOCOTinyDataset'
cfg.data.train.ann_file = f'{cfg.data_root}/train.json'
cfg.data.train.img_prefix = f'{cfg.data_root}/images/'

cfg.data.val.type = 'TopDownCOCOTinyDataset'
cfg.data.val.ann_file = f'{cfg.data_root}/val.json'
cfg.data.val.img_prefix = f'{cfg.data_root}/images/'

cfg.data.test.type = 'TopDownCOCOTinyDataset'
cfg.data.test.ann_file = f'{cfg.data_root}/val.json'
cfg.data.test.img_prefix = f'{cfg.data_root}/images/'

HRFormer の学習

keypoint detectionではまずobject detectionを行い、それで得られたbounding boxを用いてkeypoint detectionが行われます。 もちろんbounding box情報がすでに存在する場合はobject detectionのステップはスキップできます。 今回はkeypoint detectionのチュートリアルであるため、HRFormer部分のみ学習を行います。 以下のコードで上記で準備したconfigdatasetを使用してモデルの学習を行います。

datasets = [build_dataset(cfg.data.train)]
model = build_posenet(cfg.model)
mmcv.mkdir_or_exist(cfg.work_dir)
train_model(
    model, datasets, cfg, distributed=False, validate=True, meta=dict())

学習結果は以下のようになりました。

なおBest Epoch90でした。

HRFormer の推論

上記で学習して作成したモデルを用いて推論を行うことができます。 下記コマンドで実行できます。

pose_checkpoint = 'work_dirs/hrformer_small_coco_384x288/best_PCK_epoch_90.pth'
det_config = 'demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'
det_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

pose_model = init_pose_model(cfg, pose_checkpoint)
det_model = init_detector(det_config, det_checkpoint)

img = 'tests/data/coco/000000000785.jpg'

mmdet_results = inference_detector(det_model, img)
person_results = process_mmdet_results(mmdet_results, cat_id=1)
pose_results, returned_outputs = inference_top_down_pose_model(
    pose_model,
    img,
    person_results,
    bbox_thr=0.3,
    format='xyxy',
    dataset='TopDownCocoDataset')
vis_result = vis_pose_result(
    pose_model,
    img,
    pose_results,
    kpt_score_thr=0.,
    dataset='TopDownCocoDataset',
    show=False)
vis_result = cv2.resize(vis_result, dsize=None, fx=0.5, fy=0.5)
cv2_imshow(vis_result)

推論結果は以下のようになりました。

学習データや epoch 数が少ないためかHRNetの結果よりは精度が悪そうです。

まとめ

今回はHRFormerによる人間のkeypoint detectionのチュートリアルを行いました。 画像タスクでも人気のTransformerkeypoint detectionにも取り入れたものでした。 今回のチュートリアルではHRNetよりも精度が低かったですが、十分なデータと学習時間を与えればHRNetよりも精度が上がるかもしれません。

参考

2025年06月13日に最終更新
読み込み中...