Zero-DCE++ による夜間・暗所動画画質改善(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

Gitのインストール

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。


REM Git をシステム領域にインストール
winget install --scope machine --id Git.Git -e --silent
REM Git のパス設定
set "GIT_PATH=C:\Program Files\Git\cmd"
if exist "%GIT_PATH%" (
    echo "%PATH%" | find /i "%GIT_PATH%" >nul
    if errorlevel 1 setx PATH "%PATH%;%GIT_PATH%" /M >nul
)

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python numpy gitpython pillow
git clone https://github.com/Li-Chongyi/Zero-DCE_extension.git

夜間・暗所動画画質改善プログラム

AIの基本的な能力と動画編集応用

AIの基本的な能力の1つは、既存のデータから学習したパターンや特徴を理解し、新しい表現を生成することである。AIによる画像・動画処理技術の発展により、専門的な技術や経験がなくても、暗所や夜間撮影された動画の効率的な画質改善処理が可能となった。

主要技術

本プログラムは、Zero-DCE++(Zero-Reference Deep Curve Estimation)技術を用いた低照度動画画質改善システムである。Zero-DCE++は、参照画像を必要としないゼロ参照学習により、深層曲線推定を用いて画像の明度を調整する技術である。

プログラムは以下の技術を使用している。

参考文献

[1] C. Li, C. Guo, C. C. Loy, "Learning to Enhance Low-Light Image via Zero-Reference Deep Curve Estimation," IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 44, no. 8, pp. 4225-4238, 2021. https://arxiv.org/abs/2103.00860

ソースコード


# 夜間・暗所動画画質改善プログラム
# - 特徴技術名: Zero-DCE++ (Zero-Reference Deep Curve Estimation)
# - 出典: https://github.com/Li-Chongyi/Zero-DCE_extension
# - ライセンス: MIT License (ソースコード)
# - 特徴技術および学習済モデルの利用制限: 学術研究目的のみ (学習済みモデルの利用)。商用利用は作者への連絡が必須 (出典リポジトリのREADME.mdに記載)。必ず利用者自身で利用制限を確認すること。
# - 特徴機能: ゼロ参照深層曲線推定による低照度動画の画質改善。明暗混在シーンでも局所的な明度調整により結果を実現
# - AI学習済みモデル: Zero-DCE++事前学習モデル、低照度画像改善に特化、曲線推定により明度調整を実現
# - 入力と出力: 入力: 動画(ファイル/カメラ/サンプル)、出力: 改善前後の映像を並べたリアルタイム比較表示
# - 前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# pip install opencv-python numpy gitpython pillow
# git clone https://github.com/Li-Chongyi/Zero-DCE_extension.git
# - 方式設計:
#   - 関連利用技術: OpenCV(動画処理)、PyTorch(深層学習フレームワーク)
#   - 処理手順: 1.動画ソース選択、2.フレーム取得、3.解像度調整、4.Zero-DCE++モデル適用、5.改善前後の映像を連結、6.リアルタイム表示
#   - 前処理、後処理: 前処理:RGB正規化(0-1範囲)、動的解像度調整、後処理:なし
#   - 調整を必要とする設定値: MAX_RESOLUTION(最大解像度、現在1280、範囲720-1920、処理速度とのバランス)

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageDraw, ImageFont
import urllib.request
import time
from datetime import datetime
from collections import OrderedDict

# 設定値
MAX_RESOLUTION = 1280  # 最大解像度(720-1920)
MODEL_PATH = "Zero-DCE_extension/Zero-DCE++/snapshots_Zero_DCE++/Epoch99.pth"
REPO_URL = "https://github.com/Li-Chongyi/Zero-DCE_extension.git"
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE = 20

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# Depthwise Separable Convolutionブロックを定義
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.depth_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=True)
        self.point_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.depth_conv(x)
        x = self.point_conv(x)
        x = self.relu(x)
        return x

# Zero-DCE++モデル定義(最終層のReLU問題を修正した公式準拠版)
class DCENet(nn.Module):
    def __init__(self, n_curves=8):
        super(DCENet, self).__init__()
        self.n_curves = n_curves

        self.e_conv1 = ConvBlock(3, 32)
        self.e_conv2 = ConvBlock(32, 32)
        self.e_conv3 = ConvBlock(32, 32)
        self.e_conv4 = ConvBlock(32, 32)
        self.e_conv5 = ConvBlock(64, 32) # 入力: x3(32) + x4(32)
        self.e_conv6 = ConvBlock(64, 32) # 入力: x2(32) + x5(32)

        # 最終層e_conv7はReLUを適用しないため、ConvBlockを使わず個別に定義
        self.e_conv7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, groups=64, bias=True),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
        )

    def forward(self, x):
        x1 = self.e_conv1(x)
        x2 = self.e_conv2(x1)
        x3 = self.e_conv3(x2)
        x4 = self.e_conv4(x3)
        x5 = self.e_conv5(torch.cat([x3, x4], 1))
        x6 = self.e_conv6(torch.cat([x2, x5], 1))
        # 最終層の出力(ReLUなし)を直接tanhに渡す
        A_x = self.e_conv7(torch.cat([x1, x6], 1))
        A = torch.tanh(A_x)

        # 1つの補正マップAをn_curves回、繰り返し適用する
        enhanced = x
        for _ in range(self.n_curves):
            enhanced = enhanced + A * (torch.pow(enhanced, 2) - enhanced)

        return enhanced, A

def video_frame_processing(frame):
    global frame_count, model
    current_time = time.time()
    frame_count += 1

    # 解像度調整
    h, w = frame.shape[:2]
    original_size = None
    if max(h, w) > MAX_RESOLUTION:
        scale = MAX_RESOLUTION / max(h, w)
        new_w, new_h = int(w * scale), int(h * scale)
        resized = cv2.resize(frame, (new_w, new_h))
        original_size = (w, h)
    else:
        resized = frame

    # 表示用にオリジナルフレームを保持
    original_for_display = resized.copy()

    # 前処理
    frame_rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    frame_normalized = frame_rgb.astype(np.float32) / 255.0
    frame_tensor = torch.from_numpy(frame_normalized).permute(2, 0, 1).unsqueeze(0).to(device)

    # モデル推論
    with torch.no_grad():
        enhanced, _ = model(frame_tensor)
        enhanced = torch.clamp(enhanced, 0, 1)

    # 後処理(モデル出力をBGR形式に変換)
    enhanced_np = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    enhanced_bgr = cv2.cvtColor((enhanced_np * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)

    # 元解像度に復元(改善後のみ)
    if original_size:
        enhanced_bgr = cv2.resize(enhanced_bgr, original_size)
        # 表示用オリジナルも同じサイズに復元
        original_for_display = cv2.resize(original_for_display, original_size)

    # 明度改善率の計算
    original_brightness = np.mean(cv2.cvtColor(original_for_display, cv2.COLOR_BGR2GRAY))
    enhanced_brightness = np.mean(cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2GRAY))
    improvement_ratio = (enhanced_brightness / max(original_brightness, 1)) * 100

    result = f"輝度比率: {improvement_ratio:.1f}%"
    return enhanced_bgr, original_for_display, result, current_time

# ガイダンス表示
print("=" * 50)
print("夜間・暗所動画画質改善プログラム")
print("=" * 50)
print("概要: Zero-DCE++技術により低照度動画の画質を改善します")
print("注意事項: 初回実行時はモデルのダウンロードに時間がかかります")
print("=" * 50)

# 入力選択
print("0: 動画ファイル")
print("1: カメラ")
print("2: サンプル動画")

choice = input("選択: ")

frame_count = 0
results_log = []

# モデルの初期化
model = DCENet(n_curves=8).to(device)
print("モデル構造: Zero-DCE++(公式準拠)")

# モデルのダウンロードと読み込み
if not os.path.exists(MODEL_PATH):
    print("学習済みモデルをダウンロード中...")
    import git
    try:
        git.Repo.clone_from(REPO_URL, "Zero-DCE_extension")
        print("ダウンロード完了")
    except Exception as e:
        print(f"モデルのダウンロードに失敗しました: {e}")
        exit()

# モデル重みの読み込み
try:
    saved_state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)

    new_state_dict = OrderedDict()
    is_parallel = all(key.startswith('module.') for key in saved_state_dict.keys())

    for k, v in saved_state_dict.items():
        name = k[7:] if is_parallel else k

        # e_conv7のキー名を新しい構造にマッピングする
        if name.startswith('e_conv7.'):
            if 'depth_conv' in name:
                name = name.replace('depth_conv', '0')
            elif 'point_conv' in name:
                name = name.replace('point_conv', '1')

        new_state_dict[name] = v

    state_dict_to_load = new_state_dict

    # 厳格モード(strict=True)でモデルに重みを読み込む
    model.load_state_dict(state_dict_to_load, strict=True)
    print("学習済みモデルを正常に読み込みました。")

except Exception as e:
    print(f"モデルの読み込みに失敗しました: {e}")
    print("モデルのアーキテクチャと学習済みデータの間に不一致がある可能性があります。")
    exit()

model.eval()

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        exit()
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    if not cap.isOpened():
        cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
else:
    # サンプル動画ダウンロード・処理
    SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    SAMPLE_FILE = 'vtest.avi'
    urllib.request.urlretrieve(SAMPLE_URL, SAMPLE_FILE)
    cap = cv2.VideoCapture(SAMPLE_FILE)

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    exit()

# メイン処理
MAIN_FUNC_DESC = "画質改善 比較 (左:オリジナル, 右:改善後)"
print('\n=== 動画処理開始 ===')
print('操作方法:')
print('  q キー: プログラム終了')
try:
    # フォントの存在確認
    font = None
    if os.path.exists(FONT_PATH):
        font = ImageFont.truetype(FONT_PATH, FONT_SIZE)
    else:
        print(f"警告: フォントファイルが見つかりません ({FONT_PATH})。テキスト表示は行われません。")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        processed_frame, original_frame, result, current_time = video_frame_processing(frame)

        # 改善前と改善後を左右に連結
        h1, w1, _ = original_frame.shape
        h2, w2, _ = processed_frame.shape
        if h1 != h2 or w1 != w2:
             processed_frame = cv2.resize(processed_frame, (w1, h1))

        comparison_image_bgr = np.hstack((original_frame, processed_frame))

        # Pillowを使用してウィンドウに情報を描画(フォントが存在する場合のみ)
        if font:
            img_pil = Image.fromarray(cv2.cvtColor(comparison_image_bgr, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(img_pil)
            info_text = f"{datetime.fromtimestamp(current_time).strftime('%Y-%m-%d %H:%M:%S')} | {result}"
            draw.text((10, 5), info_text, font=font, fill=(0, 255, 0))
            draw.text((10, 35), "オリジナル", font=font, fill=(0, 255, 0))
            draw.text((w1 + 10, 35), "改善後", font=font, fill=(0, 255, 0))
            comparison_image = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
        else:
            # フォントが存在しない場合、テキスト描画は行わず、連結した画像をそのまま使用する
            comparison_image = comparison_image_bgr

        cv2.imshow(MAIN_FUNC_DESC, comparison_image)

        if choice == '1':
            print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
        else:
            print(frame_count, result)
        results_log.append(result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()
    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print(f'\n処理結果をresult.txtに保存しました')