PiDiNet(Pixel Difference Network)による動画エッジ検出と異常検出(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python numpy pillow

PiDiNet(Pixel Difference Network)による動画エッジ検出と異常検出プログラム

ソースコード


# PiDiNet(Pixel Difference Network)による動画エッジ検出と異常検出プログラム
# 特徴技術名: PiDiNet(Pixel Difference Network)
# 出典: Su, Z., Liu, W., Yu, Z., Hu, D., Liao, Q., Tian, Q., Pietikainen, M., & Liu, L. (2021).
#       Pixel Difference Networks for Efficient Edge Detection.
#       Proceedings of IEEE International Conference on Computer Vision (ICCV). arXiv:2108.07009
# 特徴機能: Pixel Difference Convolution (PDC)を用いた軽量エッジ検出。
#          BSDS500データセットでODS F-score 0.807を達成、96FPSの処理速度で動作
# 学習済みモデル: PiDiNet公式モデル(table5_pidinet.pth)、BSDS500データセットで学習済み
#                GitHubから自動ダウンロード対応
# 方式設計:
#   関連利用技術: PyTorch 公式PiDiNetアーキテクチャ実装、OpenCV 動画処理・表示、
#                NumPy 数値計算、Pillow 日本語テキスト表示、tkinter ファイル選択
#   入力と出力: 入力: 動画(ファイル/カメラ/サンプル)、
#              出力: マルチスケールエッジ検出結果のリアルタイム表示、異常検出結果の記録
#   処理手順: 1.公式重みダウンロード(GitHub raw URL使用)、2.モデル構築(dilations/attentions/conv_reduces構造)、
#            3.フレーム前処理(ImageNet正規化)、4.エッジ検出推論(4スケール出力)、
#            5.適応的閾値による異常検出、6.結果表示・保存
#   前処理、後処理: 前処理: 画像リサイズ(512x512)・RGB変換・正規化(ImageNet/PiDi/None選択可)、
#                  後処理: Sigmoid活性化・閾値処理・輪郭検出・異常分類(障害物/亀裂/ひび/汚れ/細部異常)
#   追加処理: 適応的閾値調整システム(平均強度ベースの自動調整)、マルチスケール異常検出(4段階)、
#            検出結果の色分け表示、DataParallelラップ対応
#   調整を必要とする設定値: 初期閾値(0.3, 0.35, 0.4, 0.45)、正規化モード(NORM_MODE環境変数)、
#                        強制再ダウンロード(FORCE_REDWN環境変数)
# 前準備
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
# pip install opencv-python numpy pillow

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tkinter as tk
from tkinter import filedialog
import urllib.request
import os
import time
import hashlib
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont

# ========== デバイス選択 ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# ========== 設定 ==========
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE = 20

# 公式重みの保存先(既存ファイルと混同しないため別名)
OFFICIAL_MODEL_PATH = 'table5_pidinet_official.pth'
# 修正:正しいGitHub rawファイルのURL形式
OFFICIAL_DOWNLOAD_URL = 'https://github.com/hellozhuo/pidinet/raw/master/trained_models/table5_pidinet.pth'

# 正規化設定(imagenet/pidi/none を環境変数で切替可能)
NORM_MODE = os.getenv('NORM_MODE', 'imagenet').lower()

# 既存ファイルを強制再ダウンロードするか(0/1)
FORCE_REDWN = os.getenv('FORCE_REDWN', '0') == '1'

# ========== 正規化プロファイル ==========
def get_normalize_tensors(mode: str, device):
    mode = mode.lower()
    if mode == 'imagenet':
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32, device=device).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32, device=device).view(3, 1, 1)
    elif mode == 'pidi':
        mean = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device=device).view(3, 1, 1)
        std = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device=device).view(3, 1, 1)
    else:
        mean = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device=device).view(3, 1, 1)
        std = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32, device=device).view(3, 1, 1)
    return mean, std

# ========== 公式PiDiNet実装(table5_pidinet準拠) ==========
class DilationModule(nn.Module):
    """Dilation Module (公式実装のdilations)"""
    def __init__(self, in_channels):
        super(DilationModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 24, kernel_size=1, bias=True)
        self.conv2_1 = nn.Conv2d(24, 24, kernel_size=3, padding=1, dilation=1, bias=False)
        self.conv2_2 = nn.Conv2d(24, 24, kernel_size=3, padding=2, dilation=2, bias=False)
        self.conv2_3 = nn.Conv2d(24, 24, kernel_size=3, padding=3, dilation=3, bias=False)
        self.conv2_4 = nn.Conv2d(24, 24, kernel_size=3, padding=4, dilation=4, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        out1 = self.conv2_1(x)
        out2 = self.conv2_2(x)
        out3 = self.conv2_3(x)
        out4 = self.conv2_4(x)
        return out1 + out2 + out3 + out4

class AttentionModule(nn.Module):
    """Attention Module (公式実装のattentions)"""
    def __init__(self, in_channels=24):
        super(AttentionModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 4, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(4, 1, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        att = self.conv1(x)
        att = self.conv2(att)
        att = torch.sigmoid(att)
        return x * att

class ConvReduceModule(nn.Module):
    """Conv Reduce Module (公式実装のconv_reduces)"""
    def __init__(self):
        super(ConvReduceModule, self).__init__()
        self.conv = nn.Conv2d(24, 1, kernel_size=1, bias=True)

    def forward(self, x):
        return self.conv(x)

class PiDiBlock(nn.Module):
    """PiDi Block (公式実装準拠)"""
    def __init__(self, in_channels, out_channels, stride=1):
        super(PiDiBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False, groups=in_channels)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.shortcut = None
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                     stride=stride, padding=0, bias=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out += residual
        return out

class OfficialPiDiNet(nn.Module):
    """公式PiDiNet実装(table5_pidinet.pth準拠)"""
    def __init__(self):
        super(OfficialPiDiNet, self).__init__()
        # Initial block
        self.init_block = nn.Conv2d(3, 60, kernel_size=3, padding=1, bias=False)

        # Stage 1: 60 channels, 3 blocks
        self.block1_1 = PiDiBlock(60, 60)
        self.block1_2 = PiDiBlock(60, 60)
        self.block1_3 = PiDiBlock(60, 60)

        # Stage 2: 120 channels, 4 blocks
        self.block2_1 = PiDiBlock(60, 120, stride=2)
        self.block2_2 = PiDiBlock(120, 120)
        self.block2_3 = PiDiBlock(120, 120)
        self.block2_4 = PiDiBlock(120, 120)

        # Stage 3: 240 channels, 4 blocks
        self.block3_1 = PiDiBlock(120, 240, stride=2)
        self.block3_2 = PiDiBlock(240, 240)
        self.block3_3 = PiDiBlock(240, 240)
        self.block3_4 = PiDiBlock(240, 240)

        # Stage 4: 240 channels, 4 blocks
        self.block4_1 = PiDiBlock(240, 240, stride=2)
        self.block4_2 = PiDiBlock(240, 240)
        self.block4_3 = PiDiBlock(240, 240)
        self.block4_4 = PiDiBlock(240, 240)

        # Side branches - 公式実装の名前に合わせる
        self.dilations = nn.ModuleList([
            DilationModule(60),
            DilationModule(120),
            DilationModule(240),
            DilationModule(240),
        ])

        self.attentions = nn.ModuleList([
            AttentionModule(24) for _ in range(4)
        ])

        self.conv_reduces = nn.ModuleList([
            ConvReduceModule() for _ in range(4)
        ])

        # Fusion classifier
        self.classifier = nn.Conv2d(4, 1, kernel_size=1, bias=True)

    def forward(self, x):
        # Backbone forward
        x0 = self.init_block(x)

        # Stage 1
        x1 = self.block1_1(x0)
        x1 = self.block1_2(x1)
        x1 = self.block1_3(x1)

        # Stage 2
        x2 = self.block2_1(x1)
        x2 = self.block2_2(x2)
        x2 = self.block2_3(x2)
        x2 = self.block2_4(x2)

        # Stage 3
        x3 = self.block3_1(x2)
        x3 = self.block3_2(x3)
        x3 = self.block3_3(x3)
        x3 = self.block3_4(x3)

        # Stage 4
        x4 = self.block4_1(x3)
        x4 = self.block4_2(x4)
        x4 = self.block4_3(x4)
        x4 = self.block4_4(x4)

        # Side outputs
        features = [x1, x2, x3, x4]
        edge_outputs = []

        for i, (feature, dilation, attention, conv_reduce) in enumerate(
            zip(features, self.dilations, self.attentions, self.conv_reduces)):
            dilated = dilation(feature)
            attended = attention(dilated)
            edge = conv_reduce(attended)

            # Upsample to match first scale
            if i > 0:
                edge = F.interpolate(edge, size=edge_outputs[0].shape[2:],
                                   mode='bilinear', align_corners=False)
            edge_outputs.append(edge)

        # Fusion
        fused = torch.cat(edge_outputs, dim=1)
        final_output = self.classifier(fused)

        return final_output, edge_outputs[0], edge_outputs[1], edge_outputs[2], edge_outputs[3]

# ========== ユーティリティ ==========
def sha256sum(path):
    try:
        h = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                h.update(chunk)
        return h.hexdigest()
    except Exception:
        return None

def summarize_keys(sd, max_items=15):
    if not isinstance(sd, dict):
        return []
    keys = list(sd.keys())
    return keys[:max_items]

def _extract_state_dict(ckpt):
    if isinstance(ckpt, dict) and 'state_dict' in ckpt and isinstance(ckpt['state_dict'], dict):
        return ckpt['state_dict']
    return ckpt if isinstance(ckpt, dict) else {}

def _compatibility_score(model_sd, ckpt_sd):
    if not isinstance(ckpt_sd, dict) or len(ckpt_sd) == 0:
        return 0.0
    total = len(model_sd)
    matched = 0
    for k, w in model_sd.items():
        if k in ckpt_sd and tuple(ckpt_sd[k].shape) == tuple(w.shape):
            matched += 1
    return matched / max(total, 1)

def diagnose_checkpoint(file_path, ckpt_sd):
    print('=== チェックポイント診断開始 ===')
    print(f'ファイル: {file_path}')
    try:
        size = os.path.getsize(file_path)
    except Exception:
        size = -1
    print(f'ファイルサイズ: {size} bytes')
    print(f'SHA256: {sha256sum(file_path)}')
    keys = list(ckpt_sd.keys())
    print(f'キー数: {len(keys)}')
    print(f'キー例(先頭): {summarize_keys(ckpt_sd, 20)}')

    # キーパターンの確認
    if len(keys) > 0:
        if any('module.' in k for k in keys):
            print('診断結果: DataParallelでラップされたモデルの重みファイルである可能性があります。')
            print('対処: module.プレフィックスを除去する必要があります。')
            return 'dataparallel'
        elif not any('block' in k or 'head' in k or 'cdcm' in k or 'csam' in k for k in keys):
            print('診断結果: PiDiNetの重み構造と一致しません。')
            return 'incompatible'

    print('診断結果: 正常なPiDiNet重みファイルの可能性があります。')
    print('=== 診断終了 ===')
    return 'normal'

# ========== モデル管理 ==========
class EdgeModelWrapper(nn.Module):
    def __init__(self, official_weight_path):
        super().__init__()
        self.model = OfficialPiDiNet().to(device)
        self.model_name = 'OfficialPiDiNet'
        self._load_weights(official_weight_path)

    def _load_weights(self, official_weight_path):
        if not official_weight_path or not os.path.exists(official_weight_path):
            raise SystemExit('公式重みファイルが見つからない: ' + str(official_weight_path))

        # ロード
        try:
            try:
                ckpt = torch.load(official_weight_path, map_location='cpu', weights_only=True)
            except TypeError:
                ckpt = torch.load(official_weight_path, map_location='cpu')
        except Exception as e:
            raise SystemExit(f'重みロードに失敗: {e}')

        ckpt_sd = _extract_state_dict(ckpt)

        # 診断
        result = diagnose_checkpoint(official_weight_path, ckpt_sd)

        # DataParallelの場合の処理
        if result == 'dataparallel':
            new_state_dict = {}
            for k, v in ckpt_sd.items():
                name = k[7:] if k.startswith('module.') else k
                new_state_dict[name] = v
            ckpt_sd = new_state_dict
            print('module.プレフィックスを除去しました。')

        # 互換性評価
        score = _compatibility_score(self.model.state_dict(), ckpt_sd)
        print(f'互換スコア(名前・形状一致率): {score:.3f}')

        # 重みの適用(strict=Falseで部分的な適用を許可)
        try:
            missing_keys, unexpected_keys = self.model.load_state_dict(ckpt_sd, strict=False)
            if missing_keys:
                print(f'警告: 以下のキーが重みファイルに存在しません: {missing_keys[:5]}...')
            if unexpected_keys:
                print(f'警告: 以下のキーはモデルに存在しません: {unexpected_keys[:5]}...')
            if not missing_keys and not unexpected_keys:
                print('公式重みを完全に適用しました。')
            else:
                print('公式重みを部分的に適用しました。')
        except Exception as e:
            print(f'警告: 重みの適用に問題がありました: {e}')
            print('デフォルトの初期化重みを使用します。')

    def forward(self, x):
        return self.model(x)

# ========== 適応的閾値 ==========
class AdaptiveThresholdSystem:
    def __init__(self):
        self.thresholds = [0.3, 0.35, 0.4, 0.45]
        self.edge_count_history = []
        self.scale_statistics = [[] for _ in range(4)]
        print("適応的閾値システム初期化完了")
        print(f"初期閾値(スケール別): {self.thresholds}")

    def update_and_adjust(self, edge_maps, total_contours):
        self.edge_count_history.append(total_contours)
        for i, edge_map in enumerate(edge_maps):
            mean_intensity = float(np.mean(edge_map))
            self.scale_statistics[i].append(mean_intensity)
            if len(self.scale_statistics[i]) > 20:
                self.scale_statistics[i].pop(0)
        if len(self.edge_count_history) > 10:
            self.edge_count_history.pop(0)
        if len(self.edge_count_history) == 10:
            for i in range(4):
                if len(self.scale_statistics[i]) >= 10:
                    scale_mean = float(np.mean(self.scale_statistics[i]))
                    target_mean = 0.2 if i < 2 else 0.15
                    # 修正:閾値調整ロジックを逆転
                    if scale_mean > target_mean * 1.3:
                        # 平均強度が高い場合は閾値を下げる(検出を抑制)
                        self.thresholds[i] = max(0.15, self.thresholds[i] - 0.03)
                    elif scale_mean < target_mean * 0.7:
                        # 平均強度が低い場合は閾値を上げる(検出を増やす)
                        self.thresholds[i] = min(0.8, self.thresholds[i] + 0.03)
            avg_count = float(np.mean(self.edge_count_history))
            if avg_count < 10 or avg_count > 100:
                print(f"閾値調整: {[f'{t:.2f}' for t in self.thresholds]} (検出数: {avg_count:.0f})")
        return self.thresholds

# ========== 幾何ユーティリティ ==========
def is_linear_obstacle(contour, image_shape):
    img_diagonal = float(np.sqrt(image_shape[0]**2 + image_shape[1]**2))
    min_length = img_diagonal * 0.05
    perimeter = cv2.arcLength(contour, True)
    if perimeter < min_length:
        return False
    rect = cv2.minAreaRect(contour)
    width, height = rect[1]
    if min(width, height) > 0:
        aspect_ratio = max(width, height) / min(width, height)
        return aspect_ratio > 8.0
    return False

# ========== マルチスケール異常検出 ==========
def detect_anomalies_multiscale(edge_maps, thresholds, image_shape):
    final_edge, edge1, edge2, edge3, edge4 = edge_maps
    scale_edges = [edge1, edge2, edge3, edge4]
    anomalies = []
    total_contours = 0
    colored_frame = np.zeros((final_edge.shape[0], final_edge.shape[1], 3), dtype=np.uint8)
    for scale_idx, (edge_map, threshold) in enumerate(zip(scale_edges, thresholds)):
        binary_edge = (edge_map > threshold).astype(np.uint8) * 255
        if scale_idx < 2:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
        else:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        binary_edge = cv2.morphologyEx(binary_edge, cv2.MORPH_CLOSE, kernel)
        contours, _ = cv2.findContours(binary_edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        total_contours += len(contours)
        for contour in contours:
            area = cv2.contourArea(contour)
            min_area = 50 + (scale_idx * 25)
            if area < min_area:
                continue
            if is_linear_obstacle(contour, image_shape):
                anomaly_type = '障害物'
                color = (0, 0, 255)
            else:
                if scale_idx == 0:
                    anomaly_type = '細部異常'
                    color = (255, 0, 0)
                elif scale_idx == 1:
                    anomaly_type = 'ひび'
                    color = (0, 255, 0)
                elif scale_idx == 2:
                    anomaly_type = '汚れ'
                    color = (0, 255, 255)
                else:
                    anomaly_type = '亀裂'
                    color = (0, 165, 255)
            M = cv2.moments(contour)
            if M['m00'] != 0:
                cx = int(M['m10'] / M['m00'])
                cy = int(M['m01'] / M['m00'])
                anomalies.append({'type': anomaly_type, 'location': (cx, cy), 'scale': scale_idx + 1, 'area': area})
                cv2.drawContours(colored_frame, [contour], -1, color, 2)
                cv2.circle(colored_frame, (cx, cy), 5, color, -1)
    return anomalies, colored_frame, total_contours

# ========== フォント ==========
def initialize_font():
    try:
        if os.path.exists(FONT_PATH):
            font = ImageFont.truetype(FONT_PATH, FONT_SIZE)
            print("日本語フォント読み込み成功")
        else:
            print("日本語フォントが見つかりません。デフォルトフォントを使用します。")
            font = ImageFont.load_default()
    except Exception as e:
        print(f"フォント読み込みエラー: {e}")
        font = ImageFont.load_default()
    return font

# ========== モデルファイルDL ==========
def download_official_weight(dst_path=OFFICIAL_MODEL_PATH, url=OFFICIAL_DOWNLOAD_URL, force=False):
    if os.path.exists(dst_path) and not force:
        print('既存の公式モデルファイルを使用します')
        return dst_path
    print('PiDiNet公式重みをダウンロード中...')
    try:
        urllib.request.urlretrieve(url, dst_path)
        print('ダウンロード完了')
    except Exception as e:
        print(f'公式モデルダウンロードエラー: {e}')
    return dst_path

# ========== 入力前処理 ==========
def preprocess_frame(frame_bgr, device, norm_mode='imagenet', input_size=(512, 512)):
    mean, std = get_normalize_tensors(norm_mode, device)
    img = cv2.resize(frame_bgr, input_size)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    ten = torch.from_numpy(img.transpose(2, 0, 1)).float().to(device) / 255.0
    ten = (ten - mean) / std
    ten = ten.unsqueeze(0)
    return ten

# ========== グローバル ==========
frame_count = 0
results_log = []
model_wrapper = None
adaptive_threshold_system = None
global_font = None
norm_mode_runtime = NORM_MODE

# ========== メイン処理 ==========
def video_frame_processing(frame):
    global frame_count, model_wrapper, adaptive_threshold_system, global_font
    current_time = time.time()
    frame_count += 1
    input_tensor = preprocess_frame(frame, device, norm_mode=norm_mode_runtime, input_size=(512, 512))
    with torch.no_grad():
        outputs = model_wrapper(input_tensor)
        fused = torch.sigmoid(outputs[0]).squeeze().cpu().numpy()
        s1 = torch.sigmoid(outputs[1]).squeeze().cpu().numpy()
        s2 = torch.sigmoid(outputs[2]).squeeze().cpu().numpy()
        s3 = torch.sigmoid(outputs[3]).squeeze().cpu().numpy()
        s4 = torch.sigmoid(outputs[4]).squeeze().cpu().numpy()
        del outputs
        del input_tensor
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    fused = cv2.resize(fused, (frame.shape[1], frame.shape[0]))
    s1 = cv2.resize(s1, (frame.shape[1], frame.shape[0]))
    s2 = cv2.resize(s2, (frame.shape[1], frame.shape[0]))
    s3 = cv2.resize(s3, (frame.shape[1], frame.shape[0]))
    s4 = cv2.resize(s4, (frame.shape[1], frame.shape[0]))
    edge_maps = (fused, s1, s2, s3, s4)
    anomalies, colored_overlay, total_contours = detect_anomalies_multiscale(edge_maps, adaptive_threshold_system.thresholds, frame.shape)
    adaptive_threshold_system.update_and_adjust([s1, s2, s3, s4], total_contours)
    result_text = f"検出数:{total_contours}"
    for anomaly in anomalies:
        if anomaly['type'] in ['障害物', '亀裂']:
            result_text += f" [{anomaly['type']}@({anomaly['location'][0]},{anomaly['location'][1]})]"
    edge_display = (fused * 255).astype(np.uint8)
    edge_colored = cv2.cvtColor(edge_display, cv2.COLOR_GRAY2BGR)
    processed_frame = cv2.addWeighted(edge_colored, 0.7, colored_overlay, 0.3, 0)
    try:
        img_pil = Image.fromarray(cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)
        draw.text((10, 10), f'フレーム: {frame_count}', font=global_font, fill=(255, 255, 255))
        draw.text((10, 35), f'検出数: {total_contours}', font=global_font, fill=(255, 255, 255))
        draw.text((10, 60), f'モデル: OfficialPiDiNet', font=global_font, fill=(0, 255, 0))
        draw.text((10, 85), f'正規化: {norm_mode_runtime}', font=global_font, fill=(0, 200, 255))
        y_offset = 115
        draw.text((10, y_offset), '【検出色分け】', font=global_font, fill=(255, 255, 255))
        draw.text((10, y_offset + 25), '赤: 障害物', font=global_font, fill=(255, 0, 0))
        draw.text((10, y_offset + 50), '青: 細部異常', font=global_font, fill=(0, 0, 255))
        draw.text((10, y_offset + 75), '緑: ひび', font=global_font, fill=(0, 255, 0))
        draw.text((10, y_offset + 100), '黄: 汚れ', font=global_font, fill=(255, 255, 0))
        draw.text((10, y_offset + 125), '橙: 亀裂', font=global_font, fill=(255, 165, 0))
        processed_frame = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    except Exception:
        pass
    return processed_frame, result_text, current_time

# ========== 初期化 ==========
print("=" * 60)
print("PiDiNetによる動画エッジ検出と異常検出(公式実装版)")
print("=" * 60)
print("【概要】PiDiNet公式アーキテクチャで重みを適用します。")
print()

global_font = initialize_font()

# 公式重みのダウンロード(必要に応じて再DL)
download_official_weight(OFFICIAL_MODEL_PATH, OFFICIAL_DOWNLOAD_URL, force=FORCE_REDWN)

# モデル初期化
try:
    model_wrapper = EdgeModelWrapper(official_weight_path=OFFICIAL_MODEL_PATH).to(device)
    model_wrapper.eval()
except SystemExit as e:
    print(str(e))
    print('提案: OFFICIAL_MODEL_PATH を削除またはFORCE_REDWN=1で再ダウンロード後に再実行すること。')
    raise

adaptive_threshold_system = AdaptiveThresholdSystem()

print('\n動画ソースを選択してください:')
print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')

choice = input('選択: ')

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        exit()
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    if not cap.isOpened():
        cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
else:
    SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    SAMPLE_FILE = 'vtest.avi'
    try:
        urllib.request.urlretrieve(SAMPLE_URL, SAMPLE_FILE)
    except Exception as e:
        print(f'サンプル動画ダウンロードエラー: {e}')
    cap = cv2.VideoCapture(SAMPLE_FILE)

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    exit()

print('\n=== 動画処理開始 ===')
print('操作方法: q キーで終了')
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        MAIN_FUNC_DESC = "PiDiNetエッジ検出"
        processed_frame, result, current_time = video_frame_processing(frame)
        cv2.imshow(MAIN_FUNC_DESC, processed_frame)
        if choice == '1':
            print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
        else:
            print(frame_count, result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()