MambaOut による画像分類

目次

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

ライブラリのインストール:

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python pillow timm

MambaOut による画像分類プログラム

概要

本プログラムは、動画の各フレームを1000種類のカテゴリに分類する。ConvNeXtアーキテクチャを用いて、リアルタイムで画像認識を行い、認識結果の確信度とともに上位5つの候補を提示する。

主要技術

1. ConvNeXt
畳み込みニューラルネットワーク(CNN)の一種である[1]。Vision Transformerの設計思想を取り入れながら、純粋なCNNとして実装されている。7×7の大きな畳み込みカーネル、深さ方向分離畳み込み、LayerNormalizationなどの特徴を持つ。

2. MambaOut
Vision Mambaから状態空間モデル(SSM)を除去した結果、ConvNeXtと同一のアーキテクチャになることを示した研究である[2]。この発見により、視覚タスクにおいてSSMが必須ではないことが示された。

参考文献

ソースコード


# プログラム名: MambaOut 画像分類プログラム
# 特徴技術名: MambaOut
# 出典: Yu, W., & Wang, X. (2025). MambaOut: Do We Really Need Mamba for Vision? In CVPR.
# 特徴機能: Gated CNNブロックを積み重ねた階層的アーキテクチャによる画像分類。7x7カーネルの深さ方向畳み込みでトークンミキシングを行う。ImageNet画像分類に関する性能が報告されている。
# 学習済みモデル: MambaOut-Tiny(5M parameters)、MambaOut-Small(25M)、MambaOut-Base(76M)、MambaOut-Kobe(11M)が利用可能。timmライブラリから事前学習済みモデルをダウンロード可能。
# 方式設計:
#   関連利用技術: timm(PyTorch Image Models)、PyTorch、OpenCV、PIL
#   入出力:
#     入力: 0=動画ファイル(tkinterで選択)、1=カメラ(OpenCVで取得)、2=サンプル動画(https://github.com/opencv/opencv/blob/master/samples/data/vtest.avi)
#     出力: 処理結果をOpenCVウィンドウにリアルタイム表示し、画面内テキストで結果を提示。各フレームごとにprint()で結果を出力し、終了時にresult.txtに保存する。
#   処理手順: 1.フレーム取得→2.224x224へリサイズ→3.ImageNet正規化→4.MambaOutモデルで推論→5.softmaxで確率計算→6.Top-K抽出(K=5既定)→7.結果を画像に描画して表示。
#   前処理・後処理:
#     前処理: Resize(224,224)、ToTensor、正規化(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
#     後処理: softmaxで確率計算、Top-K抽出(k=5)
#   調整項目: MODEL_NAME('mambaout_tiny.in1k'、'mambaout_small.in1k'、'mambaout_base.in1k'、'mambaout_kobe.in1k' 等)
# 備考: timmバージョン0.6.11以上が必要(前提)。Windows環境で 'C:/Windows/Fonts/meiryo.ttc' を使用。

import os
import time
import cv2
import numpy as np
import torch
import timm
import urllib.request
import tkinter as tk
from tkinter import filedialog
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
from datetime import datetime

# 定数
MAMBAOUT_MODELS = {
    'mambaout_tiny.in1k': {'params': '5M', 'input_size': '224x224'},
    'mambaout_small.in1k': {'params': '25M', 'input_size': '224x224'},
    'mambaout_base.in1k': {'params': '76M', 'input_size': '224x224'},
    'mambaout_kobe.in1k': {'params': '11M', 'input_size': '224x224'},
    'mambaout_femto.in1k': {'params': '3M', 'input_size': '224x224'}
}
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE = 18
RANDOM_SEED = 42
RESULT_FILE = 'result.txt'
TOP_K = 5

# 利用可能モデルの検出
def detect_available_models():
    available = []
    for name in MAMBAOUT_MODELS.keys():
        try:
            timm.create_model(name, pretrained=False)
            available.append(name)
        except Exception:
            continue
    return available

def select_model():
    available = detect_available_models()
    if not available:
        print('利用可能なMambaOutモデルが見つかりません')
        exit()
    print('利用可能なMambaOutモデル:')
    for i, name in enumerate(available):
        info = MAMBAOUT_MODELS.get(name, {})
        print(f'{i}: {name} ({info.get("params","-")} params, {info.get("input_size","-")})')
    choice = input('モデルを選択 (デフォルト: 0): ').strip()
    try:
        idx = 0 if choice == '' else int(choice)
        if 0 <= idx < len(available):
            return available[idx]
    except Exception:
        pass
    print('無効な入力です。デフォルトモデルを使用します')
    return available[0]

# ImageNet クラス名の取得(オンライン失敗時は番号ラベル)
try:
    url = 'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt'
    with urllib.request.urlopen(url) as resp:
        IMAGENET_CLASSES = [line.decode('utf-8').strip() for line in resp.readlines()]
    print('ImageNetクラス名を取得しました')
except Exception as e:
    print(f'ImageNetクラス名の取得に失敗しました: {e}')
    IMAGENET_CLASSES = [f'class_{i}' for i in range(1000)]

print('MambaOut リアルタイム動画分類プログラム')
print('')
print('概要: MambaOutモデルを用いたリアルタイム画像分類(ImageNet 1000クラス)')
print('操作方法:')
print('  q キー: プログラム終了')
print('  0: 動画ファイル, 1: カメラ, 2: サンプル動画')
print('注意: ウィンドウがアクティブな状態でキー入力を行うこと')
print('')

# 乱数シード
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# GPU/CPU自動選択(指定どおりに統一)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# モデル選択とロード
MODEL_NAME = select_model()
print(f'使用モデル: {MODEL_NAME}')
print('モデルをロード中...')
try:
    model = timm.create_model(MODEL_NAME, pretrained=True).to(device)
    model.eval()
    print('モデルのロードが完了しました')
except Exception as e:
    print(f'モデルの読み込みに失敗しました: {e}')
    exit()

# 前処理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# フォント
font = ImageFont.truetype(FONT_PATH, FONT_SIZE)

# クラス名
class_names = IMAGENET_CLASSES

# グローバルなカウンタとログ
frame_count = 0
results_log = []

def _confidence_color(prob):
    if prob >= 0.7:
        return (0, 255, 0)      # 緑
    elif prob >= 0.5:
        return (0, 255, 255)    # 黄
    elif prob >= 0.3:
        return (0, 165, 255)    # オレンジ
    else:
        return (0, 0, 255)      # 赤

def video_frame_processing(frame):
    global frame_count
    current_time = time.time()
    frame_count += 1

    # 推論
    img_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    tensor = transform(img_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(tensor)              # 形状: [1, 1000]
        probs = torch.softmax(outputs[0], dim=0)

    topk_prob, topk_idx = torch.topk(probs, k=min(TOP_K, probs.numel()))
    # 描画
    draw = ImageDraw.Draw(img_pil)
    draw.text((10, 10), 'MambaOut分類', font=font, fill=(0, 255, 0))
    draw.text((10, 35), f'上位{TOP_K}予測:', font=font, fill=(0, 255, 0))

    label_pairs = []
    for i, (p, idx) in enumerate(zip(topk_prob, topk_idx)):
        ci = idx.item()
        name = class_names[ci] if ci < len(class_names) else f'class_{ci}'
        prob = float(p.item())
        label_pairs.append(f'{i+1}.{name}({prob:.3f})')
        draw.text((10, 60 + i * 25), f'{i+1}. {name} ({prob:.3f})', font=font, fill=_confidence_color(prob))

    # 追加情報
    draw.text((10, 60 + TOP_K * 25 + 10), f'フレーム: {frame_count}', font=font, fill=(0, 255, 0))

    processed = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    # 整形結果(1行)
    result_str = 'Top-{}: {}'.format(TOP_K, ', '.join(label_pairs))
    return processed, result_str, current_time

# 入力選択
print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')
choice = input('選択: ').strip()

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename(title='動画ファイルを選択')
    if not path:
        exit()
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    if not cap.isOpened():
        cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
else:
    SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    SAMPLE_FILE = 'vtest.avi'
    try:
        urllib.request.urlretrieve(SAMPLE_URL, SAMPLE_FILE)
    except Exception as e:
        print(f'サンプル動画のダウンロードに失敗しました: {e}')
        exit()
    cap = cv2.VideoCapture(SAMPLE_FILE)

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    exit()

# メイン処理
MAIN_FUNC_DESC = 'MambaOut画像分類'
print('\n=== 動画処理開始 ===')
print('操作方法:')
print('  q キー: プログラム終了')

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        processed_frame, result, current_time = video_frame_processing(frame)
        cv2.imshow(MAIN_FUNC_DESC, processed_frame)

        if choice == '1':  # カメラ
            print(datetime.fromtimestamp(current_time).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3], result)
        else:              # 動画ファイル or サンプル
            print(frame_count, result)

        results_log.append(result)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()
    if results_log:
        with open(RESULT_FILE, 'w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print('\n処理結果をresult.txtに保存しました')

使用方法

  1. 上記のプログラムを実行する
  2. 動作確認

    • Webカメラ映像が表示される
    • 映像上にリアルタイムで分類結果が表示される
    • 上位5位の分類結果と信頼度が表示される
  3. 終了方法:映像ウィンドウで 'q' キーを押す。

実験・探求のアイデア

AIモデル選択の実験

実験要素

体験・実験・探求のアイデア