#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
extract_distance.py
Read a list of videos from a JSON file and, for each video, use DECA to extract:
  - head distance d (based on x = 1/scale from cam[:, 0], normalized as d = x / (x + k))
  - head center (x, y, z): the spatial mean of verts, then averaged over time
  - global roll: from pose[:, 0:3] (axis-angle), first converted to a rotation
    matrix, then to Euler angles, taking the Z-axis rotation (in degrees)

The following field is added on top of the existing ones:
    "head_condition": {
        "distance": d,
        "x": x_mean,
        "y": y_mean,
        "z": z_mean,
        "roll": roll_deg_mean
    }

Usage example:
conda activate deca
python extract_distance.py \
--input_json metadata.json \
--output_json metadata_with_condition.json \
--device cuda:0 \
--rasterizer_type pytorch3d \
--sample_step 10 \
--root /path/to/dataset/root \
--distance_k 1.0
"""
import os
import sys
import json
import math
import argparse
import logging
import torch
from tqdm import tqdm
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_euler_angles
# Make sure decalib can be found on the import path
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
if ROOT_DIR not in sys.path:
sys.path.insert(0, ROOT_DIR)
from decalib.deca import DECA
from decalib.datasets import datasets
from decalib.utils.config import cfg as deca_cfg
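
# Expected input JSON shape (only "file_path" is required by this script; the
# example path below is illustrative, and any extra fields are passed through
# unchanged into the output):
# [
#   {"file_path": "videos/0001.mp4", ...},
#   ...
# ]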
def compute_video_head_condition(
deca,
video_path: str,
device: str = "cuda",
iscrop: bool = True,
detector: str = "fan",
sample_step: int = 10,
distance_k: float = 1.0,
):
"""
    For a single video:
      - Estimate cam, verts, and pose per frame with DECA.
      - Per frame, compute:
          head distance d_frame
          head center (x, y, z)
          roll_deg_frame (pose[:, 0:3] -> axis-angle -> R -> Euler)
      - Average over the time dimension to obtain:
          distance, x, y, z, roll_deg

    Returns a dict:
        {
            "distance": ...,
            "x": ...,
            "y": ...,
            "z": ...,
            "roll": ...
        }
    Returns None on failure.
"""
if not os.path.exists(video_path):
logging.warning(f"Video not found: {video_path}")
return None
try:
testdata = datasets.TestData(
video_path,
iscrop=iscrop,
face_detector=detector,
sample_step=sample_step,
)
except Exception as e:
logging.exception(f"Failed to create TestData for {video_path}: {e}")
return None
eps = 1e-6
k = float(distance_k)
d_list = []
xs, ys, zs = [], [], []
rolls = []
for idx in range(len(testdata)):
batch = testdata[idx]
image = batch["image"].to(device)[None, ...] # (1, C, H, W)
with torch.no_grad():
codedict = deca.encode(image)
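        # codedict holds the per-frame parameters predicted by DECA; "cam" and
        # "pose" are read directly below, and decode() is used further down to
        # obtain the mesh vertices for the head-center estimate.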
# ------ head distance from cam ------
cam = codedict["cam"] # (1, 3): [scale, tx, ty]
s = cam[:, 0].mean().item()
x_scale = 1.0 / (s + eps)
if x_scale <= 0 or k <= 0:
d_frame = 0.5
else:
d_frame = x_scale / (x_scale + k)
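        # Worked example of the normalization d = x / (x + k): with scale
        # s = 1.0 and k = 1.0, x = 1/s = 1.0 and d = 0.5. A larger scale
        # (roughly, the face filling more of the crop) gives a smaller x and
        # pulls d toward 0, while s -> 0 pushes d toward 1, so d stays in (0, 1).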
# ------ head center (x,y,z) from verts ------
opdict, _ = deca.decode(codedict)
verts = opdict["verts"] # (1, N, 3)
center = verts.mean(dim=1)[0].detach().cpu().numpy()
cx, cy, cz = float(center[0]), float(center[1]), float(center[2])
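        # The per-frame head center is the centroid of the reconstructed mesh
        # vertices in DECA's output space; the per-frame centroids are averaged
        # over time after the loop.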
        # ------ roll from pose (axis-angle -> R -> Euler about Z) ------
if "pose" in codedict:
pose = codedict["pose"] # (1, pose_dim)
            global_aa = pose[:, 0:3]  # global head rotation
R = axis_angle_to_matrix(global_aa) # (1, 3, 3)
            # Use the "XYZ" convention; returns [rx, ry, rz]
euler = matrix_to_euler_angles(R, convention="XYZ")
            roll_rad = euler[:, 2].mean().item()  # rotation about the Z axis
roll_deg = roll_rad * 180.0 / math.pi
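            # Only the third Euler angle of the "XYZ" decomposition is kept and
            # reported as roll; its exact sign and range follow pytorch3d's
            # Euler-angle convention.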
else:
roll_deg = 0.0
d_list.append(d_frame)
xs.append(cx)
ys.append(cy)
zs.append(cz)
rolls.append(roll_deg)
if not d_list:
logging.warning(f"No frames processed for {video_path}")
return None
distance = float(sum(d_list) / len(d_list))
x_mean = float(sum(xs) / len(xs))
y_mean = float(sum(ys) / len(ys))
z_mean = float(sum(zs) / len(zs))
roll_mean = float(sum(rolls) / len(rolls))
return {
"distance": distance,
"x": x_mean,
"y": y_mean,
"z": z_mean,
"roll": roll_mean,
}
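
# Minimal standalone usage sketch (the video path and device string below are
# illustrative assumptions, not values this repository ships with):
#
#     deca_cfg.rasterizer_type = "pytorch3d"
#     deca_cfg.model.use_tex = False
#     deca = DECA(config=deca_cfg, device="cuda:0")
#     cond = compute_video_head_condition(deca, "clips/demo.mp4", device="cuda:0")
#     # cond is either None or a dict with keys distance, x, y, z, roll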
def parse_args():
parser = argparse.ArgumentParser(
description="Extract head_condition (distance + x,y,z + roll) from videos using DECA."
)
parser.add_argument(
"--input_json",
type=str,
required=True,
        help="Path to the input JSON file (a list; each element must contain a file_path field).",
)
parser.add_argument(
"--output_json",
type=str,
required=True,
        help="Path to the output JSON file (a head_condition field is added to every item).",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Device to run on, e.g. cuda:0 or cpu.",
)
parser.add_argument(
"--rasterizer_type",
type=str,
default="pytorch3d",
choices=["standard", "pytorch3d"],
        help="Rasterizer type used by DECA.",
)
parser.add_argument(
"--detector",
type=str,
default="fan",
        help="Face detector used by DECA's TestData (default: fan).",
)
parser.add_argument(
"--sample_step",
type=int,
default=10,
        help="Frame sampling interval (one frame is taken every sample_step frames).",
)
parser.add_argument(
"--iscrop",
action="store_true",
        help="Whether to crop faces the DECA way before encoding (default False; omit this flag to keep it False).",
)
parser.set_defaults(iscrop=False)
parser.add_argument(
"--root",
type=str,
default="",
        help="(Optional) Root directory of the video files. If provided and file_path is relative, the video is read from os.path.join(root, file_path).",
)
parser.add_argument(
"--distance_k",
type=float,
default=1.0,
        help="The k in the normalization d = x / (x + k), where x = 1/scale.",
)
return parser.parse_args()
def main():
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
)
args = parse_args()
torch.set_grad_enabled(False)
    # Read the input JSON
with open(args.input_json, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
        raise ValueError("input_json must be a list where each element is a sample dict.")
    # Configure DECA
device = args.device
logging.info(f"Using device: {device}")
logging.info(f"Using rasterizer_type: {args.rasterizer_type}")
logging.info(f"Using distance_k: {args.distance_k}")
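    # Only cam, verts, and pose are consumed downstream, so texture prediction
    # is left disabled.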
deca_cfg.rasterizer_type = args.rasterizer_type
deca_cfg.model.use_tex = False
deca = DECA(config=deca_cfg, device=device)
new_data = []
for item in tqdm(data, desc="Videos"):
        item = dict(item)  # copy so the original object is not modified in place
file_path = item.get("file_path", None)
if file_path is None:
            logging.warning("Item has no file_path field; filling head_condition with defaults.")
item["head_condition"] = {
"distance": 0.5,
"x": 0.0,
"y": 0.0,
"z": 0.0,
"roll": 0.0,
}
new_data.append(item)
continue
        # Join with root (only applies when file_path is a relative path)
if args.root and not os.path.isabs(file_path):
video_path = os.path.join(args.root, file_path)
else:
video_path = file_path
cond = compute_video_head_condition(
deca=deca,
video_path=video_path,
device=device,
iscrop=args.iscrop,
detector=args.detector,
sample_step=args.sample_step,
distance_k=args.distance_k,
)
if cond is None:
cond = {
"distance": 0.5,
"x": 0.0,
"y": 0.0,
"z": 0.0,
"roll": 0.0,
}
        # Only add head_condition; do not scatter distance/x/y/z/roll at the top level
item["head_condition"] = {
"distance": float(cond["distance"]),
"x": float(cond["x"]),
"y": float(cond["y"]),
"z": float(cond["z"]),
"roll": float(cond["roll"]),
}
new_data.append(item)
    # Write the result back to JSON
with open(args.output_json, "w", encoding="utf-8") as f:
json.dump(new_data, f, ensure_ascii=False, indent=2)
logging.info(f"Done. Saved to {args.output_json}")
if __name__ == "__main__":
main()