MediaPipe Face Landmarkerによる瞳孔と虹彩追跡(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install mediapipe opencv-python numpy pillow

MediaPipe Face Landmarkerによる瞳孔と虹彩追跡プログラム

概要

このプログラムは、MediaPipe Face Landmarkerを用いて動画から顔の虹彩(瞳孔周辺の色素部分)をリアルタイムで検出・追跡するシステムである。478個の顔ランドマークから虹彩位置を特定し、眼球運動の詳細な分析を行う。カルマンフィルタによる位置推定の平滑化により、安定した追跡を実現している。

主要技術

MediaPipe Face Landmarker

Googleが開発した顔ランドマーク検出技術で、機械学習モデルにより顔の478個の特徴点を検出する[1]。このモデルは顔の輪郭、眉、目、鼻、口、虹彩を含む詳細な顔構造を捉え、各ランドマークの3次元座標を推定する。特に虹彩追跡においては、インデックス468-472(右眼)と473-477(左眼)の10個のランドマークポイントを使用する。

カルマンフィルタ

1960年にRudolf E. Kálmánが開発した再帰的推定アルゴリズム[2]。本プログラムでは2次元カルマンフィルタ(状態4次元:x, y, vx, vy)を実装し、虹彩位置の予測と観測値の融合により、ノイズを除去した滑らかな追跡を実現している。プロセスノイズ(Q=0.01)と観測ノイズ(R=0.1)のパラメータ調整により、追跡精度を最適化している。

眼球運動分析アルゴリズム

サッケード(急速眼球運動)、スムースパシュート(滑動性追跡眼球運動)、固視微動の3種類の眼球運動パターンを検出・分析する。サッケード検出では速度閾値(300px/s)と最小持続時間(20ms)を基準とし、スムースパシュートでは速度範囲(30-200px/s)内での安定性を評価する[3]。固視微動はゼロクロス法により振幅と周波数を算出する。

技術的特徴

本システムは、眼球運動の正規化処理に独自の工夫を実装している。外眼角間と内眼角間の距離の平均値を基準距離として使用し、個人差や撮影距離による影響を補正する。これにより、異なる撮影条件下でも一貫した眼球運動の評価が可能となる。

Eye Aspect Ratio (EAR)による瞬き検出と虹彩非検出による瞬き検出の二重検出機構を実装している。EARは目の縦横比から算出され、閾値0.2以下で瞬きと判定する[4]。この二重検出により、部分的な遮蔽や照明変化に対してロバストな瞬き検出を実現している。

リアルタイムグラフ表示機能により、視線角度、速度トレンド、瞬き頻度トレンド、動き頻度、サッケード頻度の5つの指標を30秒の時間窓で可視化する。線形回帰によるトレンド分析により、眼球運動パターンの時間的変化を定量的に評価できる。

実装の特色

初回実行時には学習済みモデル(face_landmarker.task)を自動ダウンロードし、セットアップの簡易化を図っている。

処理結果の包括的な記録機能により、フレームごとの詳細データとサッケードイベントの統計情報をresult.txtファイルに保存する。これにより、眼球運動の事後分析や研究用データとしての活用が可能である。

入力ソースの柔軟な選択機能(動画ファイル、カメラ、サンプル動画)により、様々な用途に対応している。tkinterによるファイル選択ダイアログとOpenCVによるカメラ制御を統合し、ユーザビリティを向上させている。

取得データの詳細

基本測定データ

時系列分析データ

眼球運動パターン分析データ

出力ファイル形式

result.txtファイルには、処理フレーム数、総サッケード検出数、平均サッケード持続時間、平均サッケード最大速度の統計情報と、各フレームの詳細測定値が時系列で記録される。カメラ入力時はタイムスタンプ付き、動画ファイル入力時はフレーム番号付きで出力される。このデータ形式により、眼球運動の時間的変化の詳細な分析や、異なる条件下での比較研究が可能となる。

参考文献

[1] Lugaresi, C., et al. (2019). MediaPipe: A Framework for Building Perception Pipelines. arXiv preprint arXiv:1906.08172. https://arxiv.org/abs/1906.08172

[2] Kalman, R. E. (1960). A New Approach to Linear Filtering and Prediction Problems. Journal of Basic Engineering, 82(1), 35-45. https://doi.org/10.1115/1.3662552

[3] Leigh, R. J., & Zee, D. S. (2015). The Neurology of Eye Movements (5th ed.). Oxford University Press. https://doi.org/10.1093/med/9780199969289.001.0001

[4] Soukupová, T., & Čech, J. (2016). Eye Blink Detection Using Facial Landmarks. 21st Computer Vision Winter Workshop. https://vision.fe.uni-lj.si/cvww2016/proceedings/papers/05.pdf

用語集

ソースコード


# MediaPipe Face Landmarkerによる虹彩追跡プログラム
# 特徴技術名: MediaPipe
# 出典: MediaPipe Tasks - Google
# 特徴機能: MediaPipe Face Landmarkerによる虹彩追跡。リアルタイムで動作する目の動き検出
# 学習済みモデル: Face Landmarker事前学習済みモデル(478顔ランドマーク)
# 方式設計:
#   - 関連利用技術:
#     - MediaPipe: Googleが開発したマルチプラットフォーム機械学習ソリューション
#     - OpenCV: 画像処理、カメラ制御、描画処理、動画入出力管理
#   - 入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択)、出力: OpenCV画面でリアルタイム表示、各フレームごとの処理結果表示、プログラム終了時にresult.txtファイルに保存
#   - 処理手順: 1.フレーム取得、2.MediaPipe推論実行、3.顔ランドマーク検出、4.虹彩位置計算、5.カルマンフィルタ適用、6.サッケード検出、7.スムースパシュート評価、8.固視微動分析、9.両眼視差計算、10.虹彩追跡分析、11.虹彩円描画
#   - 前処理、後処理: 前処理:MediaPipe内部で自動実行。後処理:虹彩中心座標の計算、カルマンフィルタによる平滑化、各種眼球運動指標の算出を実施
#   - 追加処理: 左右の虹彩位置を個別に追跡し表示、サッケード検出と頻度計算、スムースパシュート追跡、固視微動分析、両眼視差計算、虹彩追跡分析
#   - 調整を必要とする設定値: CONF_THRESH(顔検出信頼度閾値、デフォルト0.5)、KALMAN_Q(プロセスノイズ)、KALMAN_R(観測ノイズ)、SACCADE_THRESH(サッケード検出速度閾値、px/s基準)、SMOOTH_PURSUIT_MIN/MAX_SPEED(スムースパシュート速度範囲、px/s基準)、FIXATION_THRESH(固視判定速度閾値)
# その他の重要事項: Windows環境専用設計、初回実行時は学習済みモデルの自動ダウンロード
# 前準備:
#   - pip install mediapipe opencv-python numpy pillow

import cv2
import tkinter as tk
from tkinter import filedialog
import os
import numpy as np
import mediapipe as mp
import warnings
import time
import urllib.request
import math
from collections import deque
from PIL import Image, ImageDraw, ImageFont
from datetime import datetime

warnings.filterwarnings('ignore')

# ===== 設定・定数管理 =====
# MediaPipe設定
BaseOptions = mp.tasks.BaseOptions
FaceLandmarker = mp.tasks.vision.FaceLandmarker
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
VisionRunningMode = mp.tasks.vision.RunningMode

# モデル選択
MODEL_SIZE = '0'  # 使用するモデルサイズ(0=標準モデル)

# モデル情報
MODEL_INFO = {
    '0': {
        'name': 'Face Landmarker',
        'desc': '顔ランドマーク検出(虹彩追跡対応)',
        'url': 'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task',
        'file': 'face_landmarker.task'
    }
}

MODEL_URL = MODEL_INFO[MODEL_SIZE]['url']
MODEL_PATH = MODEL_INFO[MODEL_SIZE]['file']

# 虹彩関連のランドマークインデックス
# MediaPipe Face Landmarkerは478個のランドマークを出力
# インデックス468-472が右眼虹彩(5点)、473-477が左眼虹彩(5点)
LEFT_IRIS_INDICES = [473, 474, 475, 476, 477]  # 左虹彩(5点)
RIGHT_IRIS_INDICES = [468, 469, 470, 471, 472]  # 右虹彩(5点)

# 目のランドマークインデックス(EAR計算用)
LEFT_EYE_INDICES = [33, 160, 158, 133, 153, 144]
RIGHT_EYE_INDICES = [362, 385, 387, 263, 373, 380]

# カメラ設定
WINDOW_WIDTH = 1280
WINDOW_HEIGHT = 720
FPS = 30  # 想定フレームレート

# 検出パラメータ(調整可能)
CONF_THRESH = 0.5
EAR_THRESH = 0.2

# カルマンフィルタ
KALMAN_Q = 0.01
KALMAN_R = 0.1

# サッケード検出
SACCADE_THRESH = 300.0  # px/s基準
SACCADE_MIN_DURATION = 0.02  # 秒

# スムースパシュート検出
SMOOTH_PURSUIT_MIN_SPEED = 30.0  # px/s基準
SMOOTH_PURSUIT_MAX_SPEED = 200.0  # px/s基準
SMOOTH_PURSUIT_WINDOW = 1.0  # 秒

# 固視検出
FIXATION_THRESH = 10.0  # px/s基準
FIXATION_WINDOW = 1.0  # 秒

# 表示設定
GRAPH_TIME_WINDOW = 30  # 秒
GRAPH_WIDTH = 400
GRAPH_HEIGHT = 120
GRAPH_MARGIN = 10
TREND_WINDOW = 10  # 秒

# 日本語フォント
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE = 16

# 動き閾値
movement_threshold = 5.0  # px/s基準

# プログラム概要表示
print('=== MediaPipe虹彩追跡プログラム ===')
print('概要: リアルタイムで虹彩(目の動き)を検出し、円で表示します')
print('機能: MediaPipe Face Landmarkerによる虹彩追跡、カルマンフィルタ平滑化、サッケード検出、')
print('      スムースパシュート追跡、固視微動分析、両眼視差計算、虹彩追跡分析')
print('操作: qキーで終了')
print('出力: 各フレームごとの処理結果表示、終了時にresult.txt保存')
print()

# 算出指標の説明
print('=== 算出される指標の説明 ===')
print()
print('【基本指標】')
print('・視線角度: 虹彩中心と目の中心を結ぶ線の角度(-180~180度)')
print('・虹彩サイズ: 左右の虹彩の直径(ピクセル単位)')
print('・EAR (Eye Aspect Ratio): 目の開き具合(0~1、小さいほど閉じている)')
print()
print('【動き・頻度指標】')
print('・速度トレンド: 虹彩移動速度の変化率(px/s²)')
print('・瞬きトレンド: 瞬き頻度の変化率(回/分²)')
print('・サッケード頻度: 急速眼球運動の頻度(回/分)')
print('・動き頻度: 虹彩の動き開始回数(回/秒)')
print()
print('【追加指標】')
print('・スムースパシュート追跡: 滑らかな追従眼球運動の評価(0-100)')
print('・固視微動分析: 注視時の微小振動(振幅px、周波数Hz)')
print('・両眼視差: 左右眼の視線角度差(度)')
print('・虹彩安定性: 虹彩サイズの標準偏差(px)')
print()

# システム初期化
print('システム初期化中...')

# モデルダウンロード
if not os.path.exists(MODEL_PATH):
    print(f'{MODEL_INFO[MODEL_SIZE]["name"]}モデルをダウンロード中...')
    try:
        urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
        print('モデルのダウンロードが完了しました')
    except Exception as e:
        print(f'モデルのダウンロードに失敗しました: {e}')
        exit()

# MediaPipeモデル初期化(Face Meshを使用)
face_mesh = None
try:
    print(f'MediaPipe Face Mesh {MODEL_INFO[MODEL_SIZE]["name"]}モデルを初期化中...')
    mp_face_mesh = mp.solutions.face_mesh
    face_mesh = mp_face_mesh.FaceMesh(
        max_num_faces=10,
        refine_landmarks=True,  # 虹彩ランドマークを有効化
        min_detection_confidence=CONF_THRESH,
        min_tracking_confidence=CONF_THRESH,
        static_image_mode=False  # 動画モード
    )
    print(f'MediaPipe Face Mesh {MODEL_INFO[MODEL_SIZE]["name"]}モデルの初期化が完了しました')
    print(f'モデル: {MODEL_INFO[MODEL_SIZE]["name"]} ({MODEL_INFO[MODEL_SIZE]["desc"]})')
    print(f'検出可能: 顔ランドマーク478点(虹彩含む)')
except Exception as e:
    print('MediaPipe Face Meshモデルの初期化に失敗しました')
    print(f'エラー: {e}')
    exit()

print('CPUモード')
print('初期化完了')
print()

class KalmanFilter2D:
    """2次元カルマンフィルタ(x, y座標用)"""
    def __init__(self, process_noise=KALMAN_Q, measurement_noise=KALMAN_R):
        self.kf = cv2.KalmanFilter(4, 2)  # 状態4次元(x,y,vx,vy)、観測2次元(x,y)
        self.kf.transitionMatrix = np.array([[1, 0, 1, 0],
                                             [0, 1, 0, 1],
                                             [0, 0, 1, 0],
                                             [0, 0, 0, 1]], np.float32)
        self.kf.measurementMatrix = np.array([[1, 0, 0, 0],
                                              [0, 1, 0, 0]], np.float32)
        self.kf.processNoiseCov = np.eye(4, dtype=np.float32) * process_noise
        self.kf.measurementNoiseCov = np.eye(2, dtype=np.float32) * measurement_noise
        self.initialized = False

    def update(self, measurement):
        """測定値で更新"""
        if not self.initialized:
            self.kf.statePre = np.array([measurement[0], measurement[1], 0, 0], np.float32)
            self.kf.statePost = np.array([measurement[0], measurement[1], 0, 0], np.float32)
            self.initialized = True
        self.kf.predict()
        measurement_array = np.array([[measurement[0]], [measurement[1]]], np.float32)
        self.kf.correct(measurement_array)
        return self.kf.statePost[0], self.kf.statePost[1]


class HistoryManager:
    """履歴管理用の共通クラス"""
    def __init__(self, maxlen=300):
        self.data_history = deque(maxlen=maxlen)
        self.time_history = deque(maxlen=maxlen)

    def append(self, data, time):
        self.data_history.append(data)
        self.time_history.append(time)

    def get_histories(self):
        return self.data_history, self.time_history


def draw_japanese_text(image, text, position, font_size=FONT_SIZE, color=(255, 255, 255)):
    """日本語テキストを画像に描画(Pillow+OpenCV)"""
    try:
        font = ImageFont.truetype(FONT_PATH, font_size)
        img_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)
        draw.text(position, text, font=font, fill=color)
        return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    except:
        cv2.putText(image, text, position, cv2.FONT_HERSHEY_SIMPLEX, font_size/30, color[::-1], 1)
        return image


def calculate_ear(eye_landmarks):
    """Eye Aspect Ratio (EAR)"""
    v1 = np.linalg.norm(eye_landmarks[1] - eye_landmarks[5])
    v2 = np.linalg.norm(eye_landmarks[2] - eye_landmarks[4])
    h = np.linalg.norm(eye_landmarks[0] - eye_landmarks[3])
    ear = (v1 + v2) / (2.0 * h) if h > 0 else 0
    return ear


def calculate_iris_size(iris_landmarks):
    """虹彩サイズ(直径相当):点間の最大距離"""
    max_dist = 0
    for i in range(len(iris_landmarks)):
        for j in range(i + 1, len(iris_landmarks)):
            dist = np.linalg.norm(iris_landmarks[i] - iris_landmarks[j])
            max_dist = max(max_dist, dist)
    return max_dist


def calculate_gaze_angle(iris_center, eye_center):
    """視線角度(画像座標上の偏角,度)と距離"""
    dx = iris_center[0] - eye_center[0]
    dy = iris_center[1] - eye_center[1]
    angle = math.degrees(math.atan2(dy, dx))
    distance = math.sqrt(dx**2 + dy**2)
    return angle, distance


def calculate_trend(data, time_data, window_seconds):
    """線形回帰でトレンド(傾き)を計算"""
    if len(data) < 2 or len(time_data) < 2:
        return None
    current_time = time_data[-1]
    valid = [(time_data[i], data[i]) for i in range(len(time_data))
             if current_time - time_data[i] <= window_seconds and data[i] is not None]
    if len(valid) < 2:
        return None
    t0 = valid[0][0]
    x = np.array([t - t0 for (t, _) in valid], dtype=np.float64)
    y = np.array([v for (_, v) in valid], dtype=np.float64)
    n = len(x)
    sum_x = np.sum(x); sum_y = np.sum(y)
    sum_xy = np.sum(x * y); sum_x2 = np.sum(x * x)
    denom = n * sum_x2 - sum_x * sum_x
    if abs(denom) < 1e-12:
        return 0.0
    return (n * sum_xy - sum_x * sum_y) / denom


def compute_eye_scale(landmarks_array):
    """基準距離C:外眼角間と内眼角間の距離の平均"""
    p33, p263 = landmarks_array[33], landmarks_array[263]
    p133, p362 = landmarks_array[133], landmarks_array[362]
    d_outer = np.linalg.norm(p263 - p33)  # 外眼角間距離
    d_inner = np.linalg.norm(p362 - p133)  # 内眼角間距離
    return max(1e-6, (d_outer + d_inner) / 2.0)


def calculate_smooth_pursuit_score(speed_norm_hist, time_hist, T=SMOOTH_PURSUIT_WINDOW,
                                   min_norm=None, max_norm=None):
    """スムースパシュートスコア(0-100)"""
    if len(speed_norm_hist) < 3 or len(time_hist) < 3:
        return 0.0
    t_now = time_hist[-1]
    idx = [i for i in range(len(time_hist)) if t_now - time_hist[i] <= T]
    if len(idx) < 3:
        return 0.0
    speeds = [speed_norm_hist[i] for i in idx]
    if min_norm is None or max_norm is None:
        return 0.0
    smooth_flags = [1 if (min_norm <= s <= max_norm) else 0 for s in speeds]
    proportion = sum(smooth_flags) / len(speeds)
    std_dev = float(np.std(speeds)) if len(speeds) > 1 else 0.0
    stability_score = max(0.0, 100.0 - 100.0 * min(1.0, std_dev))
    score = proportion * 100.0 * 0.7 + stability_score * 0.3
    return min(100.0, max(0.0, score))


def calculate_fixation_microsaccades(position_hist, time_hist, T=FIXATION_WINDOW):
    """固視微動の振幅(px)と周波数(Hz)"""
    if len(position_hist) < 3 or len(time_hist) < 3:
        return 0.0, 0.0
    t_now = time_hist[-1]
    idx = [i for i in range(len(time_hist)) if t_now - time_hist[i] <= T]
    if len(idx) < 3:
        return 0.0, 0.0
    positions = np.array([position_hist[i] for i in idx], dtype=np.float64)
    deviations = np.linalg.norm(positions - np.mean(positions, axis=0), axis=1)
    amplitude = float(np.mean(deviations)) if deviations.size > 0 else 0.0
    zero_crosses = 0
    mean_amp = amplitude
    for i in range(1, len(deviations)):
        if (deviations[i-1] - mean_amp) * (deviations[i] - mean_amp) < 0:
            zero_crosses += 1
    duration = max(1e-6, time_hist[idx[-1]] - time_hist[idx[0]])
    frequency = zero_crosses / (2.0 * duration)
    return amplitude, float(frequency)


def calculate_binocular_disparity(left_gaze_angle, right_gaze_angle):
    """両眼視差(度)"""
    return abs(left_gaze_angle - right_gaze_angle)


def calculate_iris_stability(size_hist):
    """虹彩安定性(サイズの標準偏差)"""
    if len(size_hist) < 10:
        return 0.0
    sizes = list(size_hist)[-30:]  # 最近30フレーム
    if len(sizes) < 2:
        return 0.0
    return float(np.std(sizes))


def draw_graph(frame, data, y_range, y_label, graph_x, graph_y, color=(0, 255, 0), dynamic_range=False):
    """グラフ描画"""
    cv2.rectangle(frame, (graph_x, graph_y), (graph_x + GRAPH_WIDTH, graph_y + GRAPH_HEIGHT), (50, 50, 50), -1)
    for i in range(5):
        y = graph_y + int(i * GRAPH_HEIGHT / 4)
        cv2.line(frame, (graph_x, y), (graph_x + GRAPH_WIDTH, y), (100, 100, 100), 1)
    recent_data = list(data[-int(GRAPH_TIME_WINDOW * FPS):]) if len(data) > 0 else []
    if dynamic_range and len(recent_data) > 0:
        valid_data = [d for d in recent_data if d is not None]
        if valid_data:
            dmin, dmax = min(valid_data), max(valid_data)
            margin = (dmax - dmin) * 0.1 if dmax != dmin else 1.0
            y_range = (dmin - margin, dmax + margin)
    if len(recent_data) > 1:
        points = []
        for i, value in enumerate(recent_data):
            if value is not None:
                x = graph_x + int(i * GRAPH_WIDTH / len(recent_data))
                y = graph_y + GRAPH_HEIGHT - int((value - y_range[0]) / (y_range[1] - y_range[0] + 1e-12) * GRAPH_HEIGHT)
                y = max(graph_y, min(graph_y + GRAPH_HEIGHT, y))
                points.append((x, y))
        for i in range(1, len(points)):
            cv2.line(frame, points[i-1], points[i], color, 2)
    frame = draw_japanese_text(frame, y_label, (graph_x + 5, graph_y + 5), FONT_SIZE, (255, 255, 255))
    cv2.putText(frame, f'{y_range[1]:.1f}', (graph_x + GRAPH_WIDTH - 50, graph_y + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
    cv2.putText(frame, f'{y_range[0]:.1f}', (graph_x + GRAPH_WIDTH - 50, graph_y + GRAPH_HEIGHT - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)


def calculate_frequency(event_times, current_time, window):
    """汎用的な頻度計算関数"""
    return len([t for t in event_times if current_time - t < window])


def trim_graph_data_by_time():
    """時間窓でグラフ用データをトリム"""
    if len(graph_data['time']) < 2:
        return
    while graph_data['time'][-1] - graph_data['time'][0] > GRAPH_TIME_WINDOW and len(graph_data['time']) > 1:
        for key in ['time', 'gaze_angle', 'speed', 'speed_trend', 'iris_size', 'blink_freq', 'blink_trend', 'movement_freq', 'saccade_freq']:
            if len(graph_data[key]) > 0:
                graph_data[key].popleft()


def draw_iris_circle(frame, center, color, inner_radius=5, outer_radius=15):
    """虹彩円を描画する共通関数"""
    cv2.circle(frame, tuple(center.astype(int)), inner_radius, color, -1)
    cv2.circle(frame, tuple(center.astype(int)), outer_radius, color, 2)


# グローバル変数
frame_count = 0
results_log = []
previous_positions = {}
previous_raw_positions = {}
blink_states = {}
kalman_filters = {}
saccade_states = {}
saccade_events = []
movement_states = {}
smooth_pursuit_states = {}
fixation_states = {}
iris_size_history = {}

# グラフ用データ保存
graph_data = {
    'time': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'gaze_angle': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'speed': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'speed_trend': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'iris_size': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'blink_times': deque(maxlen=1000),
    'blink_freq': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'blink_trend': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'movement_freq': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'saccade_freq': deque(maxlen=int(GRAPH_TIME_WINDOW * FPS)),
    'saccade_times': deque(maxlen=1000),
}

def video_frame_processing(frame):
    """フレーム処理メイン関数"""
    global frame_count
    current_time = time.time()
    frame_count += 1

    # 入力ソースに応じた現在時刻(秒)
    if choice == '0' or choice == '2':  # 動画ファイル
        current_timestamp = frame_count / 30.0  # 仮定FPS
    else:  # カメラ
        current_timestamp = time.monotonic()

    # 描画用拡張キャンバス
    extended_width = frame.shape[1] + GRAPH_WIDTH + GRAPH_MARGIN * 2
    extended_frame = np.zeros((frame.shape[0], extended_width, 3), dtype=np.uint8)
    extended_frame[:, :frame.shape[1]] = frame

    # RGB変換とMediaPipe処理
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = face_mesh.process(rgb_frame)

    result_text = ""

    if results.multi_face_landmarks:
        height, width = frame.shape[:2]
        for face_idx, face_landmarks in enumerate(results.multi_face_landmarks):
            if face_idx > 0:
                continue

            # ランドマーク数の検証
            landmark_count = len(face_landmarks.landmark)
            if landmark_count == 478:
                # 虹彩ランドマーク利用可能
                pass
            elif landmark_count == 468:
                print(f"警告: 虹彩ランドマークなし(Face Meshモード)- {landmark_count}個のランドマーク")
                continue
            else:
                print(f"予期しないランドマーク数: {landmark_count}")
                continue

            landmarks_array = np.array([(lm.x * width, lm.y * height) for lm in face_landmarks.landmark])

            # スケール
            scale_eye = compute_eye_scale(landmarks_array)
            saccade_thresh_norm = SACCADE_THRESH / scale_eye
            movement_thresh_norm = movement_threshold / scale_eye
            smooth_min_norm = SMOOTH_PURSUIT_MIN_SPEED / scale_eye
            smooth_max_norm = SMOOTH_PURSUIT_MAX_SPEED / scale_eye

            # EAR
            left_eye_points = landmarks_array[LEFT_EYE_INDICES]
            right_eye_points = landmarks_array[RIGHT_EYE_INDICES]
            left_ear = calculate_ear(left_eye_points)
            right_ear = calculate_ear(right_eye_points)
            avg_ear = (left_ear + right_ear) / 2.0

            # 瞬き状態管理
            face_key = f"face_{face_idx}"
            if face_key not in blink_states:
                blink_states[face_key] = {'is_blinking': False, 'ear_blink': False, 'iris_blink': False}

            ear_blink_detected = False
            if avg_ear < EAR_THRESH and not blink_states[face_key]['ear_blink']:
                blink_states[face_key]['ear_blink'] = True
                ear_blink_detected = True
                if face_idx == 0:
                    graph_data['blink_times'].append(current_timestamp)
            elif avg_ear >= EAR_THRESH:
                blink_states[face_key]['ear_blink'] = False

            # 虹彩検出と計算
            iris_detected = True
            try:
                left_iris_points = landmarks_array[LEFT_IRIS_INDICES]
                right_iris_points = landmarks_array[RIGHT_IRIS_INDICES]
                left_iris_center_raw = np.mean(left_iris_points, axis=0)
                right_iris_center_raw = np.mean(right_iris_points, axis=0)
                left_iris_size_px = calculate_iris_size(left_iris_points)
                right_iris_size_px = calculate_iris_size(right_iris_points)

                # カルマンフィルタ
                if face_key not in kalman_filters:
                    kalman_filters[face_key] = {'left': KalmanFilter2D(), 'right': KalmanFilter2D()}
                left_iris_filtered = kalman_filters[face_key]['left'].update(left_iris_center_raw)
                right_iris_filtered = kalman_filters[face_key]['right'].update(right_iris_center_raw)
                left_iris_center = np.array(left_iris_filtered)
                right_iris_center = np.array(right_iris_filtered)

                # 目の中心と視線角
                left_eye_center = np.mean(left_eye_points, axis=0)
                right_eye_center = np.mean(right_eye_points, axis=0)
                left_gaze_angle, _ = calculate_gaze_angle(left_iris_center, left_eye_center)
                right_gaze_angle, _ = calculate_gaze_angle(right_iris_center, right_eye_center)
                avg_gaze_angle = (left_gaze_angle + right_gaze_angle) / 2.0

            except:
                iris_detected = False

            if iris_detected:
                # 速度計算
                left_speed_raw = 0.0
                right_speed_raw = 0.0
                if face_key in previous_raw_positions:
                    dt = current_timestamp - previous_raw_positions[face_key]['timestamp']
                    if dt > 0:
                        left_speed_raw = float(np.linalg.norm(left_iris_center_raw - previous_raw_positions[face_key]['left']) / dt)
                        right_speed_raw = float(np.linalg.norm(right_iris_center_raw - previous_raw_positions[face_key]['right']) / dt)
                avg_speed_raw = (left_speed_raw + right_speed_raw) / 2.0
                avg_speed_norm = avg_speed_raw / scale_eye

                # 動き開始検出
                if face_key not in movement_states:
                    movement_states[face_key] = {'in_movement': False, 'movement_start_times': deque(maxlen=100)}
                if avg_speed_norm > movement_thresh_norm and not movement_states[face_key]['in_movement']:
                    movement_states[face_key]['in_movement'] = True
                    movement_states[face_key]['movement_start_times'].append(current_timestamp)
                elif avg_speed_norm <= movement_thresh_norm:
                    movement_states[face_key]['in_movement'] = False

                # サッケード検出
                if face_key not in saccade_states:
                    saccade_states[face_key] = {'in_saccade': False, 'saccade_start': 0.0, 'max_speed_px': 0.0}
                if avg_speed_norm > saccade_thresh_norm and not saccade_states[face_key]['in_saccade']:
                    saccade_states[face_key]['in_saccade'] = True
                    saccade_states[face_key]['saccade_start'] = current_timestamp
                    saccade_states[face_key]['max_speed_px'] = avg_speed_raw
                elif saccade_states[face_key]['in_saccade']:
                    saccade_states[face_key]['max_speed_px'] = max(saccade_states[face_key]['max_speed_px'], avg_speed_raw)
                    if avg_speed_norm < saccade_thresh_norm:
                        duration = current_timestamp - saccade_states[face_key]['saccade_start']
                        if duration >= SACCADE_MIN_DURATION:
                            if face_idx == 0:
                                graph_data['saccade_times'].append(current_timestamp)
                            saccade_events.append({
                                'face_id': face_idx,
                                'timestamp': current_timestamp,
                                'duration': duration,
                                'max_speed': saccade_states[face_key]['max_speed_px']
                            })
                        saccade_states[face_key]['in_saccade'] = False

                # スムースパシュート(履歴管理クラス使用)
                if face_key not in smooth_pursuit_states:
                    smooth_pursuit_states[face_key] = HistoryManager(maxlen=300)
                sph = smooth_pursuit_states[face_key]
                sph.append(avg_speed_norm, current_timestamp)
                speed_hist, time_hist = sph.get_histories()
                smooth_pursuit_score = calculate_smooth_pursuit_score(
                    speed_hist, time_hist,
                    T=SMOOTH_PURSUIT_WINDOW, min_norm=smooth_min_norm, max_norm=smooth_max_norm
                )

                # 固視微動(履歴管理クラス使用)
                if face_key not in fixation_states:
                    fixation_states[face_key] = HistoryManager(maxlen=300)
                fx = fixation_states[face_key]
                avg_position = (left_iris_center_raw + right_iris_center_raw) / 2.0
                fx.append(avg_position, current_timestamp)
                pos_hist, time_hist = fx.get_histories()
                fixation_amplitude, fixation_frequency = calculate_fixation_microsaccades(
                    pos_hist, time_hist, T=FIXATION_WINDOW
                )

                # 虹彩安定性(履歴管理クラス使用)
                if face_key not in iris_size_history:
                    iris_size_history[face_key] = HistoryManager(maxlen=300)
                ir = iris_size_history[face_key]
                avg_iris_size_px = (left_iris_size_px + right_iris_size_px) / 2.0
                ir.append(avg_iris_size_px, current_timestamp)
                size_hist, _ = ir.get_histories()
                iris_stability = calculate_iris_stability(size_hist)

                # 両眼視差
                binocular_disparity = calculate_binocular_disparity(left_gaze_angle, right_gaze_angle)

                # 前回位置の保存
                previous_raw_positions[face_key] = {'left': left_iris_center_raw.copy(),
                                                    'right': right_iris_center_raw.copy(),
                                                    'timestamp': current_timestamp}
                previous_positions[face_key] = {'left': left_iris_center.copy(),
                                                'right': right_iris_center.copy(),
                                                'timestamp': current_timestamp}

                # グラフ用データ
                if face_idx == 0:
                    graph_data['time'].append(current_timestamp)
                    graph_data['gaze_angle'].append(avg_gaze_angle)
                    graph_data['speed'].append(avg_speed_raw)
                    graph_data['iris_size'].append(avg_iris_size_px)

                    blink_freq = calculate_frequency(graph_data['blink_times'], current_timestamp, 60)
                    graph_data['blink_freq'].append(blink_freq)
                    saccade_freq = calculate_frequency(graph_data['saccade_times'], current_timestamp, 60)
                    graph_data['saccade_freq'].append(saccade_freq)
                    movement_freq = calculate_frequency(movement_states[face_key]['movement_start_times'], current_timestamp, 1)
                    graph_data['movement_freq'].append(movement_freq)

                    speed_trend = calculate_trend(list(graph_data['speed']), list(graph_data['time']), TREND_WINDOW)
                    graph_data['speed_trend'].append(0.0 if speed_trend is None else speed_trend)
                    blink_trend = calculate_trend(list(graph_data['blink_freq']), list(graph_data['time']), TREND_WINDOW)
                    graph_data['blink_trend'].append(0.0 if blink_trend is None else blink_trend)

                    trim_graph_data_by_time()

                # 結果テキスト
                result_text = (f'視線:{avg_gaze_angle:.1f}度 | '
                              f'虹彩:左{left_iris_size_px:.1f}px,右{right_iris_size_px:.1f}px | '
                              f'速度:{avg_speed_raw:.1f}px/s | '
                              f'サッケード:{saccade_freq:.0f}回/分 | '
                              f'スムース:{smooth_pursuit_score:.1f} | '
                              f'固視:{fixation_amplitude:.1f}px,{fixation_frequency:.1f}Hz | '
                              f'視差:{binocular_disparity:.1f}度 | '
                              f'安定性:{iris_stability:.1f}px')

                # 虹彩円描画(共通関数使用)
                draw_iris_circle(extended_frame, left_iris_center, (0, 255, 0))
                draw_iris_circle(extended_frame, right_iris_center, (0, 0, 255))

                label = f'顔 {face_idx+1}'
                extended_frame = draw_japanese_text(extended_frame, label,
                                                   (int(left_iris_center[0])-30, int(left_iris_center[1])-50),
                                                   20, (255, 255, 255))
                gaze_label = f'視線: {avg_gaze_angle:.1f}度'
                extended_frame = draw_japanese_text(extended_frame, gaze_label,
                                                   (int(left_iris_center[0])-30, int(left_iris_center[1])-25),
                                                   16, (255, 255, 0))

    # グラフ描画(設定リスト化)
    graph_x_start = frame.shape[1] + GRAPH_MARGIN
    total_graph_height = frame.shape[0] - 2 * GRAPH_MARGIN
    single_graph_height = (total_graph_height - 4 * GRAPH_MARGIN) // 5

    graph_configs = [
        (list(graph_data['gaze_angle']), (-180, 180), "視線角度 (度)", (0, 255, 255), False),
        (list(graph_data['speed_trend']), (-10, 10), "速度トレンド (px/s²)", (0, 255, 0), False),
        (list(graph_data['blink_trend']), (-5, 5), "瞬き頻度トレンド (回/分²)", (255, 0, 255), False),
        (list(graph_data['movement_freq']), (0, 10), "動き頻度 (回/秒)", (255, 255, 0), True),
        (list(graph_data['saccade_freq']), (0, 120), "サッケード頻度 (回/分)", (255, 128, 0), False)
    ]

    y = GRAPH_MARGIN
    for data, y_range, label, color, dynamic_range in graph_configs:
        draw_graph(extended_frame, data, y_range, label, graph_x_start, y, color, dynamic_range)
        y += single_graph_height + GRAPH_MARGIN

    # システム情報
    info1 = f'MediaPipe (CPU) | フレーム: {frame_count}'
    info2 = '操作: q=終了'
    extended_frame = draw_japanese_text(extended_frame, info1, (10, 10), 20, (255, 255, 255))
    extended_frame = draw_japanese_text(extended_frame, info2, (10, 35), 16, (255, 255, 0))

    return extended_frame, result_text, current_time

# 入力選択
print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')

choice = input('選択: ')

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        exit()
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    if not cap.isOpened():
        cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
else:
    # サンプル動画ダウンロード・処理
    SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    SAMPLE_FILE = 'vtest.avi'
    urllib.request.urlretrieve(SAMPLE_URL, SAMPLE_FILE)
    cap = cv2.VideoCapture(SAMPLE_FILE)

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    exit()

# メイン処理
print('\n=== 動画処理開始 ===')
print('操作方法:')
print('  q キー: プログラム終了')
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        MAIN_FUNC_DESC = "MediaPipe虹彩追跡"
        processed_frame, result, current_time = video_frame_processing(frame)
        cv2.imshow(MAIN_FUNC_DESC, processed_frame)
        if choice == '1':  # カメラの場合
            print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
        else:  # 動画ファイルの場合
            print(frame_count, result)
        results_log.append(result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()
    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'使用デバイス: CPU\n')
            f.write(f'総サッケード検出数: {len(saccade_events)}\n')
            if saccade_events:
                avg_duration = sum(s['duration'] for s in saccade_events) / len(saccade_events) * 1000
                avg_speed = sum(s['max_speed'] for s in saccade_events) / len(saccade_events)
                f.write(f'平均サッケード持続時間: {avg_duration:.1f}ms\n')
                f.write(f'平均サッケード最大速度: {avg_speed:.1f}px/s\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print(f'\n処理結果をresult.txtに保存しました')