Shap-E Text-to-3D Generator による多視点画像からの3次元再構成デモ

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install git+https://github.com/openai/shap-e.git
pip install trimesh matplotlib ipywidgets

Shap-E Text-to-3D Generator

ソースコード


# プログラム名: Shap-E Text-to-3D Generator
# 特徴技術名: Shap-E (条件付き拡散モデル)
# 出典: Jun, H., & Nichol, A. (2023). Shap-E: Generating conditional 3D implicit functions. arXiv:2305.02463
# 特徴機能: 条件付き拡散モデルによる暗黙関数パラメータの生成。テキストプロンプトから3Dアセットを生成し,
#           NeRFとテクスチャメッシュの多表現出力に対応するフレームワークである
# 学習済みモデル: OpenAI公式Shap-E学習済みモデル(text300M: テキスト条件,transmitter: 3D表現変換)
#                  公式URL: https://github.com/openai/shap-e
# 方式設計(本実装の要点):
#   入力: テキストプロンプト(文字列)
#   出力: 3D形状メッシュ(PLY/OBJ/STL)を保存する(本実装では形状のみを対象とする)
#   処理手順:
#     1. テキストから潜在表現を生成
#     2. decode_latent_mesh(公式API)でメッシュ抽出
#     3. メッシュ品質処理(重複面・縮退面の除去,法線処理)と中心化(スケールは変更しない)
#     4. 形式別に保存および可視化
# 調整可能な設定値(本実装の既定値):
#   GUIDANCE_SCALE: 生成品質と多様性のバランス(推奨域に合わせ 5.0)
#   GENERATION_STEPS: Karrasサンプリングのステップ数(64)
# 前準備:
#   pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
#   pip install git+https://github.com/openai/shap-e.git
#   pip install trimesh matplotlib ipywidgets

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # 可視化に使用
import trimesh
from trimesh.repair import fix_normals
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.diffusion.sample import sample_latents
from shap_e.util.notebooks import decode_latent_mesh
import warnings
warnings.filterwarnings("ignore", message="exception rendering with PyTorch3D")
warnings.filterwarnings("ignore", message="falling back on native PyTorch renderer")

# 調整可能な設定値
GUIDANCE_SCALE = 5.0           # 推奨域に合わせる
GENERATION_STEPS = 64          # Karrasステップ数
SEED = 42                      # 乱数シード(再現性)
OUTPUT_DIR = "output"          # 出力ディレクトリ
VISUALIZATION_SIZE = (12, 8)   # 可視化サイズ(インチ)

# 出力ディレクトリ作成
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 乱数シードの設定
torch.manual_seed(SEED)
np.random.seed(SEED)

# ガイダンス表示
print("=== Shap-E Text-to-3D Generator ===")
print("概要: テキストプロンプトから3Dモデルを生成します")
print("操作方法: プロンプトを入力後、Enterキーを押してください")
print("注意事項: 初回実行時はモデルのダウンロードに時間がかかります")
print("=" * 40)

def load_shap_e_models():
    """Shap-Eモデルの読み込みと実行デバイスの決定を行う"""
    print("Shap-Eモデルを読み込んでいます...")

    # デバイス設定(CUDAがあればGPUを使用)
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"使用デバイス: {device} (GPU: {torch.cuda.get_device_name(0)})")
    else:
        device = torch.device('cpu')
        print("使用デバイス: CPU")

    # Shap-Eモデルと拡散設定の読み込み
    xm = load_model('transmitter', device=device)
    model = load_model('text300M', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))

    return xm, model, diffusion, device

def generate_3d_from_text(prompt, model, diffusion, device):
    """テキストから潜在表現を生成する"""
    print(f"テキストプロンプト: {prompt}")
    print("3D形状を生成しています...")

    # デバイスに応じたFP16/FP32選択
    use_fp16 = (device.type == 'cuda')

    # 潜在表現の生成(Karrasサンプリングを使用)
    latents = sample_latents(
        batch_size=1,
        model=model,
        diffusion=diffusion,
        guidance_scale=GUIDANCE_SCALE,
        model_kwargs=dict(texts=[prompt]),
        progress=True,
        clip_denoised=True,
        use_fp16=use_fp16,
        use_karras=True,
        karras_steps=GENERATION_STEPS,
    )

    return latents

def extract_mesh_from_latent(latent, xm):
    """潜在表現からメッシュを抽出する(公式API: decode_latent_mesh を使用)"""
    # 公式経路でメッシュ抽出
    mesh_obj = decode_latent_mesh(xm, latent)
    tri_mesh = mesh_obj.tri_mesh()

    # 頂点・面の取得(Tensor/Numpy両対応)
    if hasattr(tri_mesh.verts, 'cpu'):
        vertices = tri_mesh.verts.cpu().numpy().astype(np.float32)
    else:
        vertices = np.array(tri_mesh.verts, dtype=np.float32)
    if hasattr(tri_mesh.faces, 'cpu'):
        faces = tri_mesh.faces.cpu().numpy().astype(np.int32)
    else:
        faces = np.array(tri_mesh.faces, dtype=np.int32)

    # Trimeshオブジェクトを作成
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

    # メッシュの品質処理(重複面・縮退面の除去,法線処理)
    print("メッシュの品質向上処理を実行しています...")
    mesh.remove_duplicate_faces()
    mesh.remove_degenerate_faces()
    fix_normals(mesh)  # 公式API(trimesh.repair)で法線処理

    # 中心化(スケールは変更しない)
    mesh.vertices -= mesh.vertices.mean(axis=0)

    print("Shap-Eによる3D形状生成が完了しました(形状のみ、色情報なし)")
    return mesh

def save_3d_mesh(mesh, filename_base):
    """3Dメッシュを複数形式(PLY/OBJ/STL)で保存する"""
    print("3Dメッシュを保存しています...")

    # PLY形式で保存
    ply_path = os.path.join(OUTPUT_DIR, f"{filename_base}.ply")
    mesh.export(ply_path)
    print(f"PLYファイルを保存しました: {ply_path}")

    # OBJ形式で保存
    obj_path = os.path.join(OUTPUT_DIR, f"{filename_base}.obj")
    mesh.export(obj_path)
    print(f"OBJファイルを保存しました: {obj_path}")

    # STL形式で保存
    stl_path = os.path.join(OUTPUT_DIR, f"{filename_base}.stl")
    mesh.export(stl_path)
    print(f"STLファイルを保存しました: {stl_path}")

def visualize_3d_mesh(mesh, title="Generated 3D Model"):
    """3Dメッシュを可視化する"""
    print("3Dモデルを可視化しています...")

    vertices = mesh.vertices
    faces = mesh.faces

    # 頂点数のチェック(0の場合は可視化を中止)
    if len(vertices) == 0:
        print("警告: 頂点が存在しません。可視化をスキップします。")
        return

    # 3Dプロット
    fig = plt.figure(figsize=VISUALIZATION_SIZE)
    ax = fig.add_subplot(111, projection='3d')

    # メッシュの可視化
    ax.plot_trisurf(vertices[:, 0], vertices[:, 1], vertices[:, 2],
                    triangles=faces, alpha=0.9, cmap='viridis')

    # 軸ラベルとタイトル
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(f'{title}\nVertices: {len(vertices)}, Faces: {len(faces)}')

    # 軸範囲の調整
    x_range = vertices[:, 0].max() - vertices[:, 0].min()
    y_range = vertices[:, 1].max() - vertices[:, 1].min()
    z_range = vertices[:, 2].max() - vertices[:, 2].min()

    # 範囲が0の場合のデフォルト値設定
    max_range = max(x_range, y_range, z_range) / 2.0
    if max_range == 0:
        max_range = 1.0

    mid_x = (vertices[:, 0].max() + vertices[:, 0].min()) * 0.5
    mid_y = (vertices[:, 1].max() + vertices[:, 1].min()) * 0.5
    mid_z = (vertices[:, 2].max() + vertices[:, 2].min()) * 0.5

    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    # 範囲の表示
    print(f"X軸範囲: {x_range:.3f}")
    print(f"Y軸範囲: {y_range:.3f}")
    print(f"Z軸範囲: {z_range:.3f}")

    if z_range < 0.01:
        print("警告: Z軸の範囲が小さいです。モデルが平面的になっている可能性があります。")

    plt.tight_layout()
    plt.show()

# モデル読み込み
xm, model, diffusion, device = load_shap_e_models()

# ユーザー入力
prompt = input("3D生成用のテキストプロンプトを入力してください: ")

# 3D形状生成
latents = generate_3d_from_text(prompt, model, diffusion, device)

# メッシュ抽出(公式API経路に一本化)
mesh = extract_mesh_from_latent(latents[0], xm)

# ファイル名ベースの作成
filename_base = prompt.replace(' ', '_').replace('/', '_').replace('\\', '_')[:50]

# 3D出力の保存
save_3d_mesh(mesh, filename_base)

# 可視化
visualize_3d_mesh(mesh, f"Generated: {prompt}")

print("\n処理が完了しました。")
print("生成された3Dモデルは以下の形式で保存されています:")
print("- PLYファイル: 3D編集ソフトで使用可能")
print("- OBJファイル: 汎用3D形式")
print("- STLファイル: 3Dプリンタ用")

# 推奨プロンプト
print("\n3D形状生成のヒント:")
print("- 単一のオブジェクトを指定(例: 'a chair', 'a vase')")
print("- 立体的な形状を明示(例: '3D model of a tree', 'volumetric sphere')")
print("- シンプルで明確な記述を使用")