BasicSR による超解像(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install basicsr opencv-python numpy pillow

BasicSR動画超解像処理プログラム

概要

このプログラムは、BasicSRフレームワークを基盤として動画の超解像処理を実現する。RRDBNetアーキテクチャを採用し、入力動画の各フレームを4倍の解像度に拡大する。動画ファイル、ウェブカメラ、サンプル動画の3種類の入力ソースに対応し、リアルタイムで処理結果を表示しながら、処理情報をテキストファイルに記録する。

主要技術

BasicSR (Basic Super-Resolution framework)

画像・動画復元タスクのための汎用的な深層学習フレームワーク[1]。様々な超解像モデルの実装基盤として広く利用されており、本プログラムではその中核アーキテクチャであるRRDBNetを採用している。

RRDB (Residual-in-Residual Dense Block) Network

ESRGANで提案された深層ネットワークアーキテクチャ[2]。複数のResidual Dense Blockを階層的に組み合わせることで、画像の詳細な特徴を保持しながら高品質な超解像を実現する。本実装では23個のRRDBブロックを使用し、64チャンネルの特徴マップを処理する。

技術的特徴

実装の特色

参考文献

[1] Wang, X., et al. (2020). BasicSR: Open Source Image and Video Restoration Toolbox. GitHub repository. https://github.com/XPixelGroup/BasicSR

[2] Wang, X., et al. (2018). ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. In Proceedings of the European Conference on Computer Vision (ECCV) Workshops. https://arxiv.org/abs/1809.00219

ソースコード


# プログラム名: BasicSR動画超解像処理プログラム (高機能版)
# 特徴技術名: BasicSR (Basic Super-Resolution framework)
# 特徴機能: RRDBNetを使用した動的モデル選択・高品質動画生成
# 入力: 動画(動画ファイル、カメラ、サンプル動画)
# 出力: リアルタイム表示、処理結果をresult.txtに保存、音声付き動画ファイル(MP4)を生成
# 前準備:
#   pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
#   pip install basicsr opencv-python numpy pillow scikit-image requests

import sys, types
try:
    import torchvision.transforms.functional_tensor as _ft  # noqa
except ModuleNotFoundError:
    import torchvision.transforms.functional as _F
    _module = types.ModuleType('torchvision.transforms.functional_tensor')
    _module.rgb_to_grayscale = _F.rgb_to_grayscale
    sys.modules['torchvision.transforms.functional_tensor'] = _module

import cv2
import torch
import numpy as np
import os
import tkinter as tk
from tkinter import filedialog
import urllib.request
import time
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from basicsr.archs.rrdbnet_arch import RRDBNet
# from basicsr.utils.download_util import load_file_from_url # requestsによる自前実装に変更
import requests # ファイルダウンロード機能のために追加
import subprocess
from pathlib import Path # ファイルパス表現方法の統一のために追加
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# --- 定数の一元管理 (pathlib対応) ---
WEIGHTS_DIR = Path('weights')
RESULT_FILE = Path('result.txt')
OUTPUT_VIDEO_FILE = Path('enhanced_output.mp4')
SAMPLE_FILE = Path('vtest.avi')
FONT_PATH = Path('C:/Windows/Fonts/meiryo.ttc')
FONT_SIZE = 20
FONT_COLOR = (0, 255, 0)
TEXT_POSITION = (10, 10)
MAIN_FUNC_DESC = "BasicSR超解像処理"

# --- 詳細なモデル選択機能 ---
MODEL_INFO = {
    'RealESRGAN_x4plus': {
        'name': 'RealESRGAN x4plus',
        'description': '汎用実写画像向け、標準品質',
        'scale': 4,
        'blocks': 23,
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
    },
    'RealESRGAN_x4plus_anime_6B': {
        'name': 'RealESRGAN x4plus Anime 6B',
        'description': 'アニメ画像特化、軽量モデル',
        'scale': 4,
        'blocks': 6,
        'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
    }
}

# --- ファイルダウンロード機能の実装変更 ---
def download_file_from_url(url, model_dir, progress=True, file_name=None):
    model_dir.mkdir(parents=True, exist_ok=True)
    if file_name is None:
        file_name = url.split('/')[-1]
    file_path = model_dir / file_name

    if file_path.exists():
        return file_path

    print(f'ダウンロード中: {url}')
    response = requests.get(url, stream=True)
    response.raise_for_status()

    total_size = int(response.headers.get('content-length', 0))
    downloaded = 0

    with open(file_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                f.write(chunk)
                downloaded += len(chunk)
                if progress and total_size > 0:
                    percent = (downloaded / total_size) * 100
                    print(f'\rダウンロード進捗: {percent:.1f}%', end='', flush=True)
    if progress:
        print('\nダウンロード完了')
    return file_path

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# --- GPUメモリに応じた動的な設定 ---
USE_HALF = False
if device.type == 'cuda':
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    if gpu_memory_gb >= 4:
        USE_HALF = True
        print(f'GPUメモリ ({gpu_memory_gb:.1f}GB) を検出しました。半精度(FP16)を有効化します。')
    else:
        print(f'GPUメモリ ({gpu_memory_gb:.1f}GB) が4GB未満のため、半精度(FP16)を無効化します。')

# FFmpeg/ffprobe利用可能性チェック
FFMPEG_AVAILABLE = False
try:
    subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True, text=True)
    FFMPEG_AVAILABLE = True
except (subprocess.CalledProcessError, FileNotFoundError):
    print('警告: ffmpegが見つかりません。動画出力機能は無効になります。')

# 超解像処理プロセッサクラス
class SuperResolutionProcessor:
    def __init__(self, model, device, use_half=False):
        self.model = model
        self.device = device
        self.use_half = use_half and self.device.type == 'cuda'

        self.model.to(self.device)
        if self.use_half:
            self.model.half()

    def process(self, img_tensor):
        if self.use_half:
            img_tensor = img_tensor.half()

        # --- GPUメモリ不足時のフォールバック処理 ---
        try:
            with torch.no_grad():
                output = self.model(img_tensor)
        except RuntimeError as e:
            if 'out of memory' in str(e).lower():
                print('警告: GPUメモリが不足しました。CPUに切り替えて処理を続行します。')
                torch.cuda.empty_cache()
                self.device = torch.device('cpu')
                self.model.to(self.device).float() # CPUでは半精度は使わない
                self.use_half = False
                img_tensor = img_tensor.to(self.device)
                with torch.no_grad():
                    output = self.model(img_tensor)
            else:
                raise e # その他のエラーは再送出

        return output

# --- メインプログラム ---
print('\n=== BasicSR動画超解像処理プログラム (高機能版) ===')

# モデル選択
print('=== モデル選択 ===')
models = list(MODEL_INFO.keys())
for i, model_key in enumerate(models, 1):
    info = MODEL_INFO[model_key]
    print(f'{i}. {info["name"]} ({info["description"]})')

while True:
    try:
        model_choice_idx = int(input(f'モデルを選択してください (1-{len(models)}): ')) - 1
        if 0 <= model_choice_idx < len(models):
            selected_model_key = models[model_choice_idx]
            break
        else:
            print(f'1から{len(models)}の間で選択してください。')
    except ValueError:
        print('数値を入力してください。')

model_info = MODEL_INFO[selected_model_key]
print(f'\'{model_info["name"]}\' をロード中...')

# モデル初期化
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=model_info['blocks'], num_grow_ch=32, scale=model_info['scale'])
ckpt_path = download_file_from_url(url=model_info['url'], model_dir=WEIGHTS_DIR, file_name=f"{selected_model_key}.pth")
ckpt = torch.load(ckpt_path, map_location='cpu')
state = ckpt.get('params_ema') or ckpt.get('params') or ckpt
model.load_state_dict(state, strict=True)
model.eval()

# プロセッサのインスタンス化
processor = SuperResolutionProcessor(model, device, use_half=USE_HALF)

# フォント設定 (pathlib対応)
if FONT_PATH.exists():
    font = ImageFont.truetype(str(FONT_PATH), FONT_SIZE)
    use_japanese_font = True
else:
    print('警告: 日本語フォントが見つかりません。英語表示になります。')
    use_japanese_font = False

frame_count = 0
results_log = []

def video_frame_processing(frame, processor_instance):
    global frame_count
    current_time = time.time()
    frame_count += 1

    # BGR→RGB変換
    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).float().div(255).unsqueeze(0).to(processor_instance.device)

    # 推論実行
    output = processor_instance.process(img_tensor)

    # 後処理:RGB→BGR、uint8化
    output = output.squeeze(0).clamp(0, 1).mul(255).round().to(torch.uint8).cpu().permute(1, 2, 0).numpy()
    processed_frame = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)

    # --- 客観的な品質評価指標の導入 ---
    original_resized = cv2.resize(frame, (processed_frame.shape[1], processed_frame.shape[0]), interpolation=cv2.INTER_LANCZOS4)
    psnr_val = psnr(original_resized, processed_frame, data_range=255)
    ssim_val = ssim(original_resized, processed_frame, channel_axis=2, data_range=255)

    # 情報テキスト追加
    info_text = f'Frame: {frame_count} | PSNR: {psnr_val:.2f}dB | SSIM: {ssim_val:.4f}'
    if use_japanese_font:
        img_pil = Image.fromarray(cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)
        draw.text(TEXT_POSITION, info_text, font=font, fill=FONT_COLOR)
        processed_frame = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    else:
        cv2.putText(processed_frame, info_text, TEXT_POSITION, cv2.FONT_HERSHEY_SIMPLEX, 0.6, FONT_COLOR, 2)

    result = f'解像度: {frame.shape[1]}x{frame.shape[0]} → {processed_frame.shape[1]}x{processed_frame.shape[0]}, PSNR: {psnr_val:.2f}dB, SSIM: {ssim_val:.4f}'
    return processed_frame, result, current_time

print("\n入力ソースを選択してください:")
print("0: 動画ファイル")
print("1: カメラ")
print("2: サンプル動画")

choice = input("選択: ")
original_path = None

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    original_path_str = filedialog.askopenfilename(
        title="動画ファイルを選択",
        filetypes=[("Video files", "*.mp4 *.avi *.mov *.mkv")]
    )
    if not original_path_str:
        exit()
    original_path = Path(original_path_str)
    cap = cv2.VideoCapture(str(original_path))
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    if not cap.isOpened():
        cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
else:
    SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    print('サンプル動画をダウンロード中...')
    urllib.request.urlretrieve(SAMPLE_URL, str(SAMPLE_FILE))
    original_path = SAMPLE_FILE
    cap = cv2.VideoCapture(str(original_path))

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    exit()

# --- 外部ツール(FFmpeg)連携による動画再構築機能 (pathlib対応) ---
frames_dir = None
if choice != '1' and FFMPEG_AVAILABLE:
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    frames_dir = Path(f'frames_{timestamp}')
    frames_dir.mkdir(exist_ok=True)
    print(f'処理フレームは \'{frames_dir}\' に一時保存されます。')

# メイン処理
print('\n=== 動画処理開始 ===')
print('操作方法:')
print('  q キー: プログラム終了')
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        processed_frame, result, current_time = video_frame_processing(frame, processor)
        cv2.imshow(MAIN_FUNC_DESC, processed_frame)

        if choice == '1':
            log_entry = f'{datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]} {result}'
        else:
            log_entry = f'Frame {frame_count}, {result}'

        print(log_entry)
        results_log.append(log_entry)

        if frames_dir:
            save_path = frames_dir / f'{frame_count:06d}.png'
            cv2.imwrite(str(save_path), processed_frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()

    if frames_dir and frame_count > 0:
        print('処理済みフレームと音声を結合して動画ファイルを生成中...')
        try:
            probe_cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=r_frame_rate', '-of', 'default=noprint_wrappers=1:nokey=1', str(original_path)]
            res = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
            framerate = res.stdout.strip()
        except (subprocess.CalledProcessError, FileNotFoundError):
            framerate = '25'
            print(f'ffprobeでのフレームレート取得に失敗。デフォルト値 ({framerate}) を使用します。')

        ffmpeg_cmd = [
            'ffmpeg', '-y', '-framerate', framerate, '-i', str(frames_dir / '%06d.png'),
            '-i', str(original_path), '-map', '0:v', '-map', '1:a?', '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p', '-c:a', 'aac', '-shortest', str(OUTPUT_VIDEO_FILE)
        ]
        result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
        if result.returncode == 0:
            print(f'動画を \'{OUTPUT_VIDEO_FILE}\' に保存しました。')
        else:
            print('動画の生成に失敗しました。FFmpegのエラー:')
            print(result.stderr)

    if choice == '2' and SAMPLE_FILE.exists():
        SAMPLE_FILE.unlink()

    if results_log:
        with RESULT_FILE.open('w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'使用モデル: {model_info["name"]}\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'最終使用デバイス: {str(processor.device).upper()}\n')
            if 'cuda' in str(processor.device):
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write('\n' + '\n'.join(results_log))
        print(f'\n処理結果を \'{RESULT_FILE}\' に保存しました。')