#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
extract_distance.py

Read a JSON list of samples and, for every referenced video, use DECA to
extract:
  - head distance d: x = 1 / cam[:, 0] (the camera scale), normalised as
    d = x / (x + k)
  - head centre (x, y, z): mean of the decoded verts over vertices, then
    averaged over time
  - global roll: pose[:, 0:3] (axis-angle) -> rotation matrix -> Euler angles
    ("XYZ" convention); the Z-axis rotation is kept, in degrees

Every sample keeps its original fields and gains:
    "head_condition": {
        "distance": d,
        "x": x_mean,
        "y": y_mean,
        "z": z_mean,
        "roll": roll_deg_mean
    }

Example:
    conda activate deca
    python extract_distance.py \
        --input_json metadata.json \
        --output_json metadata_with_condition.json \
        --device cuda:0 \
        --rasterizer_type pytorch3d \
        --sample_step 10 \
        --root /path/to/dataset/root \
        --distance_k 1.0
"""

import os
import sys
import json
import math
import argparse
import logging

import torch
from tqdm import tqdm
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_euler_angles

# Make the ``decalib`` package that lives next to this script importable.
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from decalib.deca import DECA  # noqa: E402  (requires ROOT_DIR on sys.path)
from decalib.datasets import datasets  # noqa: E402
from decalib.utils.config import cfg as deca_cfg  # noqa: E402


# Written for samples without a file_path and for videos that could not be
# processed. Key order here is also the key order of every emitted
# "head_condition" dict.
DEFAULT_HEAD_CONDITION = {
    "distance": 0.5,
    "x": 0.0,
    "y": 0.0,
    "z": 0.0,
    "roll": 0.0,
}


def _normalized_distance(scale, k, eps=1e-6):
    """Map a DECA camera scale to the normalised distance d = x / (x + k).

    x = 1 / (scale + eps) is used as a distance proxy: a smaller scale gives a
    larger x and therefore a larger d in (0, 1).

    Returns 0.5 when the mapping is undefined, i.e. when ``scale + eps`` is
    non-positive or non-finite (this also avoids a division by zero and NaN
    leaking into the JSON output) or when ``k`` is non-positive.
    """
    denom = scale + eps
    if not math.isfinite(denom) or denom <= 0 or k <= 0:
        return 0.5
    x_scale = 1.0 / denom
    return x_scale / (x_scale + k)


def _roll_degrees(codedict):
    """Return the Z-axis Euler angle (degrees) of the global head rotation.

    pose[:, 0:3] is treated as the global axis-angle rotation. It is converted
    to a rotation matrix and decomposed with the "XYZ" convention, and the
    third angle (about Z) is returned, averaged over the batch. Returns 0.0
    when ``codedict`` has no "pose" entry.
    """
    if "pose" not in codedict:
        return 0.0
    global_aa = codedict["pose"][:, 0:3]
    rot = axis_angle_to_matrix(global_aa)  # (B, 3, 3)
    euler = matrix_to_euler_angles(rot, convention="XYZ")  # (B, 3): [rx, ry, rz]
    roll_rad = euler[:, 2].mean().item()
    return roll_rad * 180.0 / math.pi


def _frame_head_condition(deca, image, k):
    """Run DECA on one (1, C, H, W) image; return (d, cx, cy, cz, roll_deg)."""
    codedict = deca.encode(image)

    # Head distance from the camera; cam is (1, 3): [scale, tx, ty].
    scale = codedict["cam"][:, 0].mean().item()
    d_frame = _normalized_distance(scale, k)

    # Head centre: mean over the vertex axis of the decoded verts (1, N, 3).
    # NOTE(review): these are presumably FLAME-space vertices before DECA's
    # orthographic camera is applied (DECA also exposes "trans_verts");
    # confirm this is the intended frame for x/y/z.
    # NOTE(review): decode() with default arguments appears to also render
    # visualisations that are discarded here; DECA's decode may accept flags
    # to skip that work - verify against the decalib version in use.
    opdict, _ = deca.decode(codedict)
    center = opdict["verts"].mean(dim=1)[0].detach().cpu().tolist()
    cx, cy, cz = (float(c) for c in center[:3])

    return d_frame, cx, cy, cz, _roll_degrees(codedict)


def _mean(values):
    """Arithmetic mean of a non-empty sequence, as a Python float."""
    return float(sum(values) / len(values))


def compute_video_head_condition(
    deca,
    video_path: str,
    device: str = "cuda",
    iscrop: bool = True,
    detector: str = "fan",
    sample_step: int = 10,
    distance_k: float = 1.0,
):
    """Estimate the time-averaged head condition of one video with DECA.

    For every frame yielded by DECA's ``TestData`` the function computes the
    normalised head distance, the head centre (x, y, z) and the roll angle
    (pose[:, 0:3] -> axis-angle -> rotation matrix -> Euler Z, degrees), then
    averages each quantity over all successfully processed frames.

    Args:
        deca: A constructed DECA model.
        video_path: Path to the video file.
        device: Torch device the frames are moved to.
        iscrop: Passed to ``TestData`` (crop the face DECA-style first).
        detector: Face detector name passed to ``TestData``.
        sample_step: Frame sampling step passed to ``TestData``.
        distance_k: ``k`` in d = x / (x + k).

    Returns:
        dict with keys "distance", "x", "y", "z", "roll", or None when the
        video is missing, ``TestData`` cannot be built, or no frame could be
        processed. Frames that raise during loading/inference are logged and
        skipped instead of aborting the caller's batch.
    """
    if not os.path.exists(video_path):
        logging.warning(f"Video not found: {video_path}")
        return None

    try:
        testdata = datasets.TestData(
            video_path,
            iscrop=iscrop,
            face_detector=detector,
            sample_step=sample_step,
        )
    except Exception as e:
        logging.exception(f"Failed to create TestData for {video_path}: {e}")
        return None

    k = float(distance_k)
    d_list, xs, ys, zs, rolls = [], [], [], [], []

    for idx in range(len(testdata)):
        try:
            batch = testdata[idx]
            image = batch["image"].to(device)[None, ...]  # (1, C, H, W)
            # no_grad covers decode too, so no autograd graph is built even
            # when the caller has gradients enabled.
            with torch.no_grad():
                d_frame, cx, cy, cz, roll_deg = _frame_head_condition(deca, image, k)
        except Exception as e:
            logging.exception(f"Failed to process frame {idx} of {video_path}, skipping: {e}")
            continue

        d_list.append(d_frame)
        xs.append(cx)
        ys.append(cy)
        zs.append(cz)
        rolls.append(roll_deg)

    if not d_list:
        logging.warning(f"No frames processed for {video_path}")
        return None

    # NOTE(review): roll is averaged arithmetically; values straddling +/-180
    # degrees would not average correctly (fine for near-upright heads).
    return {
        "distance": _mean(d_list),
        "x": _mean(xs),
        "y": _mean(ys),
        "z": _mean(zs),
        "roll": _mean(rolls),
    }


def parse_args():
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser(
        description="Extract head_condition (distance + x,y,z + roll) from videos using DECA."
    )
    parser.add_argument(
        "--input_json",
        type=str,
        required=True,
        help="输入 JSON 文件路径(列表,每个元素里有 file_path 字段)。",
    )
    parser.add_argument(
        "--output_json",
        type=str,
        required=True,
        help="输出 JSON 文件路径(会写入 head_condition 字段)。",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="运行设备,如 cuda:0 或 cpu。",
    )
    parser.add_argument(
        "--rasterizer_type",
        type=str,
        default="pytorch3d",
        choices=["standard", "pytorch3d"],
        help="DECA 使用的 rasterizer 类型。",
    )
    parser.add_argument(
        "--detector",
        type=str,
        default="fan",
        help="DECA TestData 使用的人脸检测器(默认 fan)。",
    )
    parser.add_argument(
        "--sample_step",
        type=int,
        default=10,
        help="视频帧采样间隔(每隔 sample_step 帧取一帧)。",
    )
    # store_true already defaults to False.
    parser.add_argument(
        "--iscrop",
        action="store_true",
        help="是否按照 DECA 的方式先做人脸裁剪(默认 False,不加此参数即为 False)。",
    )
    parser.add_argument(
        "--root",
        type=str,
        default="",
        help="(可选)视频文件的根目录。如果提供且 file_path 为相对路径,则用 os.path.join(root, file_path) 读取视频。",
    )
    parser.add_argument(
        "--distance_k",
        type=float,
        default=1.0,
        help="归一化函数 d = x / (x + k) 中的 k 值(x = 1/scale)。",
    )
    return parser.parse_args()


def main():
    """Annotate every sample of --input_json with "head_condition" and save it."""
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
    )
    args = parse_args()
    torch.set_grad_enabled(False)

    # Load the sample list.
    with open(args.input_json, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("input_json 必须是一个列表,每个元素是一个样本字典。")

    # Create the output directory now, so a bad path fails before the long
    # processing loop rather than after it.
    out_dir = os.path.dirname(args.output_json)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Configure DECA.
    device = args.device
    logging.info(f"Using device: {device}")
    logging.info(f"Using rasterizer_type: {args.rasterizer_type}")
    logging.info(f"Using distance_k: {args.distance_k}")
    if args.distance_k <= 0:
        logging.warning("distance_k <= 0: every distance will fall back to 0.5.")
    deca_cfg.rasterizer_type = args.rasterizer_type
    deca_cfg.model.use_tex = False
    deca = DECA(config=deca_cfg, device=device)

    new_data = []
    for item in tqdm(data, desc="Videos"):
        item = dict(item)  # shallow copy so the loaded objects are not mutated
        file_path = item.get("file_path", None)

        if file_path is None:
            logging.warning("Item has no file_path field, fill head_condition with defaults.")
            item["head_condition"] = dict(DEFAULT_HEAD_CONDITION)
            new_data.append(item)
            continue

        # --root only applies to relative file paths.
        if args.root and not os.path.isabs(file_path):
            video_path = os.path.join(args.root, file_path)
        else:
            video_path = file_path

        cond = compute_video_head_condition(
            deca=deca,
            video_path=video_path,
            device=device,
            iscrop=args.iscrop,
            detector=args.detector,
            sample_step=args.sample_step,
            distance_k=args.distance_k,
        )
        if cond is None:
            cond = DEFAULT_HEAD_CONDITION

        # Only add the nested "head_condition"; no top-level distance/x/y/z/roll.
        item["head_condition"] = {key: float(cond[key]) for key in DEFAULT_HEAD_CONDITION}
        new_data.append(item)

    # Write the annotated list.
    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(new_data, f, ensure_ascii=False, indent=2)

    logging.info(f"Done. Saved to {args.output_json}")


if __name__ == "__main__":
    main()