SAM2による前景背景分離(ソースコード)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install opencv-python numpy pillow gitpython requests

SAM2による前景背景分離プログラム

ソースコード


"""
SAM2による前景背景分離プログラム

特徴技術名:SAM2(Segment Anything Model 2)
出典:Ravi, N., Gabeur, V., Hu, Y. T., Hu, R., Ryali, C., Ma, T., ... & Feichtenhofer, C. (2024). SAM 2: Segment anything in images and videos. arXiv preprint arXiv:2408.00714.
特徴機能:promptable visual segmentation - 画像と動画の両方で、ユーザーからのプロンプト(点、ボックス、マスク等)に基づいてリアルタイムでオブジェクトをセグメンテーションする機能。本プログラムでは点プロンプトのみを使用する。

学習済みモデル:
- sam2_hiera_large.pt:最高精度モデル(636M parameters)- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
- sam2_hiera_base_plus.pt:バランス型モデル(308M parameters)- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt
- sam2_hiera_small.pt:軽量モデル - https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt
- sam2_hiera_tiny.pt:最軽量モデル(91M parameters)- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt

方式設計:
- 関連利用技術:
  - OpenCV:カメラ入力と画像処理(リアルタイム表示、マウスイベント処理)
  - PyTorch:深層学習フレームワーク(SAM2モデルの実行)
  - gitpython:Gitリポジトリの自動クローン(SAM2公式実装の取得)
  - PIL:画像処理ライブラリ(日本語テキスト描画)
  - NumPy:数値計算(配列操作、座標処理)
- 入力と出力:
  入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択。0:動画ファイルの場合はtkinterでファイル選択。1の場合はOpenCVでカメラが開く。2の場合はhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.aviを使用)
  出力: リアルタイム前景背景分離結果をOpenCV画面で表示。元画像、前景画像、マスク画像を同時表示。OpenCV画面内に処理結果をテキストで表示。各フレームごとに処理結果を表示。プログラム終了時に処理結果をresult.txtファイルに保存し、「result.txtに保存」したことをprint()で表示。プログラム開始時に、プログラムの概要、ユーザが行う必要がある操作をprint()で表示。
- 処理手順:
  1. gitpythonによるSAM2リポジトリの自動ダウンロード
  2. 事前学習済みモデル重みの自動取得(選択モデルのみ)
  3. SAM2モデルの初期化とメモリ配置
  4. 動画ソースからのフレーム取得
  5. マウスクリックによるプロンプトポイント設定(左クリック:前景,右クリック:背景)
  6. SAM2によるセグメンテーション実行
  7. 前景抽出とマスク生成
  8. リアルタイム結果表示
- 前処理:RGB色空間変換(OpenCVのBGRからRGBへの変換)
- 後処理:セグメンテーションマスクのスコア評価による最適マスク選択、前景領域の抽出処理
- 追加処理:メモリ効率化のためのGPU/CPU自動選択、マウスコールバックによるインタラクティブなプロンプト設定

調整を必要とする設定値:
- model_cfg:モデル設定ファイル('sam2_hiera_l.yaml'等)- 使用するSAM2モデルの種類を決定
- camera_id:カメラデバイス番号(通常は0)- 複数カメラ接続時の選択

将来方策:プログラム内でのモデル性能評価機能を追加し、利用可能なGPUメモリ量に基づいて最適なmodel_cfg(tiny/small/base/large)を自動選択する機能の実装

その他の重要事項:
- Windows環境専用設計
- CUDA対応GPU推奨(CPU動作も可能だが処理速度が低下)
- 初回実行時にモデルおよびリポジトリのダウンロードが発生(選択するモデルにより総量は数百MB〜約1GB程度)

前準備:
- pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
- pip install opencv-python numpy pillow gitpython requests
"""

import os
import sys
import cv2
import numpy as np
import torch
from pathlib import Path
import git
from PIL import Image, ImageDraw, ImageFont
import urllib.request
import tkinter as tk
from tkinter import filedialog
import time
from datetime import datetime

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# 設定値
base_dir = './sam2_models'
repo_url = 'https://github.com/facebookresearch/segment-anything-2.git'

# グローバル変数
predictor = None
point_coords_list = []
point_labels_list = []
frame_count = 0
results_log = []

# マウスコールバック用のデータクラス
class MouseCallbackData:
    def __init__(self):
        self.frame_width = 0

callback_data = MouseCallbackData()

def download_sam2_repository():
    repo_dir = Path(base_dir) / 'segment-anything-2'
    if repo_dir.exists():
        print('SAM2リポジトリは既に存在します')
        return str(repo_dir)

    Path(base_dir).mkdir(exist_ok=True)
    try:
        git.Repo.clone_from(repo_url, repo_dir)
        print('SAM2リポジトリのダウンロード完了')
        return str(repo_dir)
    except Exception as e:
        print(f'SAM2リポジトリのダウンロードに失敗しました: {e}')
        exit()

def download_model_weights():
    model_urls = {
        'sam2_hiera_tiny.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt',
        'sam2_hiera_small.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt',
        'sam2_hiera_base_plus.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt',
        'sam2_hiera_large.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'
    }

    checkpoints_dir = Path(base_dir) / 'checkpoints'
    checkpoints_dir.mkdir(exist_ok=True)

    # 選択されたモデルのみダウンロード
    if model_file in model_urls:
        model_path = checkpoints_dir / model_file
        if model_path.exists():
            print(f'モデルは既に存在: {model_file}')
        else:
            print(f'モデルをダウンロード中: {model_file}')
            try:
                urllib.request.urlretrieve(model_urls[model_file], model_path)
                print(f'モデルダウンロード完了: {model_file}')
            except Exception as e:
                print(f'モデル重みのダウンロードに失敗しました: {e}')
                exit()

def initialize_sam2():
    global predictor

    repo_dir = download_sam2_repository()
    download_model_weights()

    # sys.pathにSAM2ディレクトリを追加
    sys.path.insert(0, repo_dir)

    # PYTHONPATHにSAM2リポジトリを追加
    if 'PYTHONPATH' in os.environ:
        os.environ['PYTHONPATH'] = f"{repo_dir}{os.pathsep}{os.environ['PYTHONPATH']}"
    else:
        os.environ['PYTHONPATH'] = repo_dir

    # 作業ディレクトリをSAM2リポジトリに変更
    original_cwd = os.getcwd()
    os.chdir(repo_dir)

    try:
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        print(f'使用デバイス: {device}')

        # チェックポイントファイルの絶対パスを設定
        checkpoint_path = str(Path(original_cwd) / base_dir / 'checkpoints' / model_file)

        sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
        predictor = SAM2ImagePredictor(sam2_model)
        print('SAM2モデルの初期化完了')

    except Exception as e:
        print(f'SAM2モデルの初期化に失敗しました: {e}')
        exit()
    finally:
        # 作業ディレクトリを元に戻す
        os.chdir(original_cwd)

def video_frame_processing(frame):
    global frame_count, predictor, point_coords_list, point_labels_list, callback_data
    current_time = time.time()
    frame_count += 1

    # 元画像の幅を保存
    callback_data.frame_width = frame.shape[1]

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    if len(point_coords_list) == 0:
        # ポイントが指定されていない場合は元画像をそのまま返す
        result = f'ポイント未指定'
        return frame, result, current_time

    try:
        predictor.set_image(frame_rgb)

        # 複数のプロンプトポイントを配列として渡す
        point_coords = np.array(point_coords_list)
        point_labels = np.array(point_labels_list)

        masks, scores, logits = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True
        )

        # 最高スコアのマスクを選択(複数マスク候補から最適なものを選択)
        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx]

        # マスクをboolean型に変換
        mask_bool = mask.astype(bool)

        foreground = frame_rgb.copy()
        foreground[~mask_bool] = [0, 0, 0]
        foreground_bgr = cv2.cvtColor(foreground, cv2.COLOR_RGB2BGR)

        # マスクを3チャンネルに変換(uint8型に変換)
        mask_display = (mask_bool.astype(np.uint8) * 255)
        mask_3ch = cv2.cvtColor(mask_display, cv2.COLOR_GRAY2BGR)

        # 結果を横に並べて表示
        combined = np.hstack((frame, foreground_bgr, mask_3ch))

        # 日本語テキスト表示用の準備
        FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
        FONT_SIZE = 20

        font = ImageFont.truetype(FONT_PATH, FONT_SIZE)
        img_pil = Image.fromarray(cv2.cvtColor(combined, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)

        # テキスト情報を画像に描画
        draw.text((10, 30), "元画像", font=font, fill=(255, 255, 255))
        draw.text((frame.shape[1] + 10, 30), "前景", font=font, fill=(255, 255, 255))
        draw.text((frame.shape[1] * 2 + 10, 30), "マスク", font=font, fill=(255, 255, 255))

        # デバッグ情報を画面に表示
        debug_text = f'ポイント数: {len(point_coords_list)}'
        draw.text((10, combined.shape[0] - 40), debug_text, font=font, fill=(0, 255, 255))

        combined = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

        # プロンプトポイントを描画(元画像部分に描画)
        for i, (coord, label) in enumerate(zip(point_coords_list, point_labels_list)):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            # 元画像の座標にポイントを描画
            cv2.circle(combined, tuple(coord), 5, color, -1)
            # ポイント番号も表示
            cv2.putText(combined, str(i+1), (coord[0]+8, coord[1]-8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        result = f'前景背景分離実行中 - ポイント数: {len(point_coords_list)}, スコア: {scores[best_mask_idx]:.3f}'
        return combined, result, current_time

    except Exception as e:
        result = f'前景抽出エラー: {e}'
        return frame, result, current_time

def mouse_callback(event, x, y, flags, param):
    global point_coords_list, point_labels_list

    # パラメータから元画像の幅を取得
    if param and hasattr(param, 'frame_width'):
        frame_width = param.frame_width
    else:
        return

    # 元画像の範囲内のクリックのみ処理
    if x >= frame_width:
        return

    if event == cv2.EVENT_LBUTTONDOWN:
        point_coords_list.append([x, y])
        point_labels_list.append(1)
        print(f'前景ポイント追加: ({x}, {y}) - 総数: {len(point_coords_list)}')
    elif event == cv2.EVENT_RBUTTONDOWN:
        point_coords_list.append([x, y])
        point_labels_list.append(0)
        print(f'背景ポイント追加: ({x}, {y}) - 総数: {len(point_coords_list)}')
    elif event == cv2.EVENT_MBUTTONDOWN:
        # 中ボタンクリックで全てのポイントをクリア
        point_coords_list.clear()
        point_labels_list.clear()
        print('全ポイントクリア')

print('SAM2による前景背景分離プログラム')
print('概要: 動画からSAM2を使用してリアルタイムで前景背景を分離します')
print('')
print('基本操作: 前景ポイント1つ + 背景ポイント1つを指定')
print('- 前景ポイント: 分離したいオブジェクトの上で左クリック(緑色の丸)')
print('- 背景ポイント: 除外したい部分で右クリック(赤色の丸)')
print('- 精度向上のため複数ポイント指定可能')
print('')
print('操作方法:')
print('- 左クリック: 前景ポイント追加(元画像部分のみ)')
print('- 右クリック: 背景ポイント追加(元画像部分のみ)')
print('- 中ボタンクリック: 全ポイントクリア')
print('- q: 終了')
print('')
print('注意事項:')
print('- 初回実行時はモデルのダウンロードに時間がかかります')
print('- GPU使用時は処理が高速化されます')
print('')

print('モデル選択:')
print('1: sam2_hiera_tiny.pt (155.9MB, 91M parameters)')
print('2: sam2_hiera_small.pt (180MB, parameters未公開)')
print('3: sam2_hiera_base_plus.pt (320MB, 308M parameters)')
print('4: sam2_hiera_large.pt (900MB, 636M parameters)')

model_choice = input('モデル選択 (1-4): ')

if model_choice == '1':
    model_file = 'sam2_hiera_tiny.pt'
    model_cfg = 'sam2_hiera_t.yaml'
elif model_choice == '2':
    model_file = 'sam2_hiera_small.pt'
    model_cfg = 'sam2_hiera_s.yaml'
elif model_choice == '3':
    model_file = 'sam2_hiera_base_plus.pt'
    model_cfg = 'sam2_hiera_b+.yaml'
elif model_choice == '4':
    model_file = 'sam2_hiera_large.pt'
    model_cfg = 'sam2_hiera_l.yaml'
else:
    print('無効な選択です')
    exit()

print(f'選択されたモデル: {model_file}')
print('')

initialize_sam2()

print("0: 動画ファイル")
print("1: カメラ")
print("2: サンプル動画")

choice = input("選択: ")

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        exit()
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    if not cap.isOpened():
        cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
else:
    # サンプル動画ダウンロード・処理
    SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    SAMPLE_FILE = 'vtest.avi'
    urllib.request.urlretrieve(SAMPLE_URL, SAMPLE_FILE)
    cap = cv2.VideoCapture(SAMPLE_FILE)

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    exit()

# メイン処理
print('\n=== 動画処理開始 ===')
print('操作方法:')
print('  q キー: プログラム終了')

# ウィンドウ作成とマウスコールバック設定
cv2.namedWindow('SAM2前景背景分離', cv2.WINDOW_AUTOSIZE)
cv2.setMouseCallback('SAM2前景背景分離', mouse_callback, callback_data)

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        MAIN_FUNC_DESC = "SAM2前景背景分離"
        processed_frame, result, current_time = video_frame_processing(frame)
        cv2.imshow(MAIN_FUNC_DESC, processed_frame)
        if choice == '1':  # カメラの場合
            print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
        else:  # 動画ファイルの場合
            print(frame_count, result)
        results_log.append(result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()
    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print(f'\n処理結果をresult.txtに保存しました')
>