timm による静止画像分類(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install -U "timm>=1.0" pillow opencv-python requests

timm forward_intermediates 画像分類プログラム

概要

入力画像を解析してImageNet-1kデータセットの1000クラスから物体カテゴリを特定する。Vision Transformerの中間層から特徴マップを抽出し、モデルの判断根拠を可視化する。

主要技術

参考文献

ソースコード


# timm forward_intermediates ViT中間特徴ヒートマップ分類プログラム
# 特徴技術名: forward_intermediates() API
# 出典: Wightman, R. (2025). PyTorch Image Models (timm). https://github.com/huggingface/pytorch-image-models
# 概要: Vision Transformerの中間層特徴をforward_intermediates()で取得し、ヒートマップ可視化と画像分類を行う
# 学習済みモデル詳細:
#   ViT Tiny: ImageNet-21k (1400万画像、21,843クラス) で事前学習、ImageNet-1k (100万画像、1,000クラス) でファインチューニング
#   ViT Base: ImageNet-21k で事前学習、ImageNet-1k でaugreg2技術によりファインチューニング
#   ViT Large: ImageNet-21k で事前学習、ImageNet-1k でaugregファインチューニング
#   augreg技術: 「How to train your ViT?」論文 (Steiner et al., 2021) のAugmentation & Regularization手法
#   特徴: パッチベース画像処理、Transformer attention機構、転移学習対応、URL: HuggingFace Hub経由
# 方式設計:
#   関連利用技術: OpenCV(画像入出力と表示)、PyTorch(テンソル操作)、Pillow(画像前処理と日本語テキスト表示)、tkinter(ファイル選択)
#   入力と出力: 入力: 1つの静止画像,カメラ(ユーザは「0:画像ファイル,1:カメラ,2:サンプル画像」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択可能.1の場合はOpenCVでカメラが開き,スペースキーで撮影.2の場合はhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/fruits.jpg とhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/messi5.jpgとhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/aero3.jpgとhttps://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpgからinput()で選択)、出力: 画像分類結果と中間層特徴マップ可視化
#   処理手順: 1.画像読み込み 2.前処理(リサイズ、正規化) 3.forward_intermediates()で中間層特徴抽出 4.分類結果取得 5.特徴マップ可視化
#   前処理、後処理: 前処理: 画像リサイズ(224x224)、テンソル正規化、後処理: 特徴マップの平均チャンネル計算と可視化変換
#   追加処理: 中間層特徴マップの平均チャンネル計算と可視化による解釈性向上
#   調整を必要とする設定値: 利用可能なモデル(ViT Tiny/Base/Large)から選択、intermediate_indices(抽出する中間層のインデックス)
# 将来方策: 複数の中間層比較機能、異なるViTモデルでの性能比較機能
# その他の重要事項: 特徴マップは平均チャンネルで可視化、分類確信度表示
# ===== 前準備(推奨) =====
# pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# pip install -U "timm>=1.0" pillow opencv-python requests

# ===== 設定値 =====
IMAGE_SIZE = 224
FONT_SIZE = 20
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'  # Windows以外では自動でフォールバック
TEXT_COLOR = (255, 255, 255)

# 追加設定:ヒートマップ可視化設定
TOP_PERCENT = 20.0      # ヒートマップの上位20%のみ可視化
ALPHA = 0.6             # ヒートマップの最大不透明度

import os
import cv2
import tkinter as tk
from tkinter import filedialog
import urllib.request
import time
from datetime import datetime
import math
import torch
import timm
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import json
import requests

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# ===== 利用可能なモデル設定 =====
# ※ モデルカード: HF timm/vit_* に準拠
available_models = {
    '0': {
        'name': 'vit_tiny_patch16_224.augreg_in21k_ft_in1k',
        'desc': 'ViT Tiny (224)',
        'indices': [3, 6, 9]
    },
    '1': {
        'name': 'vit_base_patch16_224.augreg2_in21k_ft_in1k',
        'desc': 'ViT Base (224, augreg2)',
        'indices': [3, 6, 9]
    },
    '2': {
        'name': 'vit_large_patch16_224.augreg_in21k_ft_in1k',
        'desc': 'ViT Large (224)',
        'indices': [6, 12, 18]
    }
}

def _safe_font(path: str, size: int):
    try:
        return ImageFont.truetype(path, size)
    except Exception:
        return ImageFont.load_default()

def put_japanese_text(img_bgr: np.ndarray, text: str, position, font_size=FONT_SIZE, color=TEXT_COLOR):
    font = _safe_font(FONT_PATH, font_size)
    img_pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    draw.text(position, text, font=font, fill=color)
    return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

def load_imagenet_classes():
    """
    ImageNet-1kクラス名を取得。
    """
    url = "https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json"
    r = requests.get(url, timeout=10)
    r.raise_for_status()
    idx_map = json.loads(r.text)
    return {int(k): v[1] for k, v in idx_map.items()}

def select_model():
    print("\n=== モデル選択 ===")
    for k, v in available_models.items():
        print(f"{k}: {v['desc']}  ->  {v['name']}  indices={v['indices']}")
    while True:
        choice = input("モデルを選択してください (0-2): ")
        if choice in available_models:
            m = available_models[choice]
            print(f"選択されたモデル: {m['name']}")
            return m['name'], m['indices']
        print("無効な選択です。0-2の範囲で選択してください。")

def build_preprocess(model):
    """
    timm公式の前処理パイプラインでモデル設定に従う(解像度/正規化/補間等)
    """
    cfg = timm.data.resolve_model_data_config(model)
    tfm = timm.data.create_transform(**cfg, is_training=False)
    return tfm

def _to_pil_from_bgr(img_bgr: np.ndarray) -> Image.Image:
    return Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))

def _norm_0_255(x: np.ndarray) -> np.ndarray:
    x = x - x.min()
    if x.max() > 0:
        x = x / x.max()
    return (x * 255.0).astype(np.uint8)

def feature_to_heatmap(feat: torch.Tensor, model, image_hw):
    """
    forward_intermediates() で得た1つの中間特徴テンソルをヒートマップ(np.uint8 HxW)に変換。
    入力 feat 形状の想定:
      - [B, C, H, W] : Conv/ViT(マップ成形済み) -> チャネル平均
      - [B, N, C]    : ViT(トークン列)          -> パッチグリッドに整形してチャネル平均
    """
    with torch.no_grad():
        if feat.dim() == 4:  # [B, C, H, W]
            m = feat.mean(dim=1)[0].detach().cpu().numpy()  # [H, W]
        elif feat.dim() == 3:  # [B, N, C]
            tokens = feat[0]  # [N, C]
            n = tokens.shape[0]
            # グリッドサイズをモデルから取得できれば使用
            grid_size = getattr(getattr(model, 'patch_embed', None), 'grid_size', None)
            if grid_size is not None:
                H, W = int(grid_size[0]), int(grid_size[1])
                # CLSトークンが含まれる場合は先頭1個を除外
                if n == 1 + H * W:
                    tokens = tokens[1:]
            else:
                # 推定(最も近い正方)
                side = int(math.sqrt(n))
                # CLS推定
                if side * side + 1 == n:
                    tokens = tokens[1:]
                    n = n - 1
                side = int(math.sqrt(tokens.shape[0]))
                H = W = side
            m = tokens.mean(dim=1).reshape(H, W).detach().cpu().numpy()
        else:
            raise RuntimeError(f"未知の特徴形状: {tuple(feat.shape)}")

    hm = _norm_0_255(m)
    hm = cv2.resize(hm, (image_hw[1], image_hw[0]), interpolation=cv2.INTER_CUBIC)
    return hm  # uint8 [H, W]

def overlay_heatmap_on_image_soft_topk(img_bgr: np.ndarray, heatmap_gray: np.ndarray, top_percent=20.0, alpha=0.6):
    """
    heatmap_grayの上位top_percent%のみ可視化し、閾値以上で線形に不透明度を上げてブレンドする。
    """
    h = heatmap_gray.astype(np.float32)
    vmax = float(h.max())
    if vmax <= 0:
        return img_bgr
    # 上位20%のみ表示 => 80パーセンタイルを閾値に設定
    perc = max(0.0, min(100.0, 100.0 - float(top_percent)))
    t = np.percentile(h, perc)
    if vmax <= t:
        return img_bgr
    alpha_map = np.clip((h - t) / (vmax - t), 0.0, 1.0) * float(alpha)
    alpha_map_3 = np.repeat(alpha_map[:, :, None], 3, axis=2)

    heatmap_color = cv2.applyColorMap(heatmap_gray, cv2.COLORMAP_JET).astype(np.float32)
    base = img_bgr.astype(np.float32)
    out = (base * (1.0 - alpha_map_3) + heatmap_color * alpha_map_3).astype(np.uint8)
    return out

def get_intermediates_and_logits(model, x, indices):
    """
    timm v1.0+ forward_intermediates() を使用。
    戻り値: (final_feat, intermediates(list[Tensor]), logits)
    """
    if not hasattr(model, 'forward_intermediates'):
        raise RuntimeError("このtimmバージョンはforward_intermediates()をサポートしていません。timm>=1.0 に更新してください。")
    final_feat, intermediates = model.forward_intermediates(x, indices=indices)  # 公式仕様
    logits = model.forward_head(final_feat)
    return final_feat, intermediates, logits

def format_topk(logits, k=5):
    prob = torch.softmax(logits, dim=1)[0]
    top_p, top_i = torch.topk(prob, k)
    return [(int(i), float(p)) for p, i in zip(top_p, top_i)]

# ===== 静止画処理テンプレ統一 =====
results_log = []

def image_processing(img_bgr: np.ndarray):
    """
    入力: BGR画像
    出力: (可視化画像BGR, 結果文字列, 現在時刻)
    """
    global preprocess, model, idx_to_labels, chosen_indices
    now = time.time()

    pil_img = _to_pil_from_bgr(img_bgr)
    x = preprocess(pil_img).unsqueeze(0).to(device)

    with torch.inference_mode():
        final_feat, intermediates, logits = get_intermediates_and_logits(model, x, indices=chosen_indices)

    # 最も深いインデックスのヒートマップをオーバレイ表示(上位20%のみソフト閾値で可視化)
    last_feat = intermediates[-1] if isinstance(intermediates, (list, tuple)) else intermediates
    hm = feature_to_heatmap(last_feat, model, (img_bgr.shape[0], img_bgr.shape[1]))
    vis = overlay_heatmap_on_image_soft_topk(img_bgr, hm, top_percent=TOP_PERCENT, alpha=ALPHA)

    # Top-5表示テキスト
    top5 = format_topk(logits, k=5)
    lines = []
    for rank, (cls_idx, p) in enumerate(top5, start=1):
        name = idx_to_labels.get(cls_idx, str(cls_idx))
        lines.append(f"{rank}. {name}: {p*100:.1f}%")
    txt = "分類結果 Top-5\n" + "\n".join(lines) + f"\nindices={chosen_indices}"
    vis = put_japanese_text(vis, txt, (10, 10), font_size=18, color=(255, 255, 255))

    # ログ用はTop-5表記
    result_str = "Top5=" + ", ".join([f"{idx_to_labels.get(i, str(i))}:{p*100:.1f}%" for i, p in top5]) + f", indices={chosen_indices}"
    return vis, result_str, now

def process_and_display_images(image_sources, source_type):
    display_index = 1
    for source in image_sources:
        img = cv2.imread(source) if source_type == 'file' else source
        if img is None:
            continue
        cv2.imshow(f'Image_{display_index}', img)
        processed_img, result, current_time = image_processing(img)
        cv2.imshow(f'ViT中間特徴ヒートマップ+分類_{display_index}', processed_img)
        print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
        results_log.append(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + " " + result)
        display_index += 1

# ===== メイン(静止画/カメラ、ユーザーガイダンス含む) =====
print("\n=== 概要 ===")
print("timmのforward_intermediates()でViTの中間特徴を取得し、ヒートマップを重畳して可視化・分類結果を表示します。")
print("操作方法:")
print("  0: 画像ファイルを選択して処理")
print("  1: カメラ映像からスペースキーで静止画をキャプチャして処理(qで終了)")
print("  2: サンプル画像URL(数枚)をダウンロードして処理")
model_name, chosen_indices = select_model()

# モデルと前処理を準備
model = timm.create_model(model_name, pretrained=True).to(device).eval()
preprocess = build_preprocess(model)
idx_to_labels = load_imagenet_classes()

print("\n=== 入力選択 ===")
print("0: 画像ファイル")
print("1: カメラ")
print("2: サンプル画像")
choice = input("選択: ")

try:
    if choice == '0':
        root = tk.Tk()
        root.withdraw()
        if not (paths := filedialog.askopenfilenames()):
            raise SystemExit
        process_and_display_images(paths, 'file')
        cv2.waitKey(0)

    elif choice == '1':
        # OpenCVでのカメラ開始(統一書式)
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        if not cap.isOpened():
            cap = cv2.VideoCapture(0)
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                cv2.imshow('Camera', frame)
                key = cv2.waitKey(1) & 0xFF
                if key == ord(' '):
                    processed_img, result, current_time = image_processing(frame)
                    cv2.imshow('ViT中間特徴ヒートマップ+分類', processed_img)
                    print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
                    results_log.append(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + " " + result)
                elif key == ord('q'):
                    break
        finally:
            cap.release()

    else:
        urls = [
            "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/fruits.jpg",
            "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/messi5.jpg",
            "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/aero3.jpg",
            "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg"
        ]
        downloaded_files = []
        for i, url in enumerate(urls):
            try:
                urllib.request.urlretrieve(url, f"sample_{i}.jpg")
                downloaded_files.append(f"sample_{i}.jpg")
            except Exception:
                print(f"画像のダウンロードに失敗しました: {url}")
        process_and_display_images(downloaded_files, 'file')
        cv2.waitKey(0)

finally:
    print('\n=== プログラム終了 ===')
    cv2.destroyAllWindows()
    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            for line in results_log:
                f.write(line + '\n')
        print(f'\n処理結果をresult.txtに保存しました')