|
|
|
|
|
|
|
|
""" |
|
|
extract_distance.py |
|
|
|
|
|
从 JSON 列表中读取多个视频,对每个视频用 DECA 提取: |
|
|
- head distance d(基于 cam[:,0] 的 1/scale,经 d = x/(x+k) 归一化) |
|
|
- 头部中心 (x, y, z):verts 的空间均值,再对时间均值 |
|
|
- 全局 roll:来自 pose[:,0:3](axis-angle), |
|
|
先转旋转矩阵,再转 Euler,取 Z 轴旋转(单位:度) |
|
|
|
|
|
在原字段基础上新增: |
|
|
"head_condition": { |
|
|
"distance": d, |
|
|
"x": x_mean, |
|
|
"y": y_mean, |
|
|
"z": z_mean, |
|
|
"roll": roll_deg_mean |
|
|
} |
|
|
|
|
|
用法示例: |
|
|
conda activate deca |
|
|
python extract_distance.py \ |
|
|
--input_json metadata.json \ |
|
|
--output_json metadata_with_condition.json \ |
|
|
--device cuda:0 \ |
|
|
--rasterizer_type pytorch3d \ |
|
|
--sample_step 10 \ |
|
|
--root /path/to/dataset/root \ |
|
|
--distance_k 1.0 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import math |
|
|
import argparse |
|
|
import logging |
|
|
|
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_euler_angles |
|
|
|
|
|
|
|
|
# Directory containing this script. It is pushed to the front of the module
# search path (only if it is not already listed) -- presumably so that the
# project-local `decalib` imports that follow resolve no matter which working
# directory the script is launched from.
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
_root_already_on_path = ROOT_DIR in sys.path
if not _root_already_on_path:
    sys.path.insert(0, ROOT_DIR)
|
|
|
|
|
from decalib.deca import DECA |
|
|
from decalib.datasets import datasets |
|
|
from decalib.utils.config import cfg as deca_cfg |
|
|
|
|
|
|
|
|
def _normalized_distance(scale, k, eps=1e-6):
    """Squash a DECA camera scale into a distance-like scalar.

    Computes ``x = 1 / (scale + eps)`` and returns ``x / (x + k)``; for
    positive finite inputs the result lies strictly between 0 and 1.

    The neutral value 0.5 is returned whenever that formula is undefined or
    meaningless: ``k <= 0``, or ``scale + eps`` is non-positive, NaN or
    infinite. Checking the denominator *before* dividing avoids a
    ZeroDivisionError and keeps NaN out of the per-video average.
    """
    denom = scale + eps
    if k <= 0 or not math.isfinite(denom) or denom <= 0:
        return 0.5
    x = 1.0 / denom
    return x / (x + k)


def _roll_degrees(codedict):
    """Return the global head roll (degrees) encoded in ``codedict["pose"]``.

    The first three pose components are treated as an axis-angle rotation,
    converted to a rotation matrix and then to "XYZ" Euler angles; the third
    (Z) angle, averaged over the batch dimension, is the roll. Returns 0.0
    when the code dictionary carries no "pose" entry.
    """
    if "pose" not in codedict:
        return 0.0
    rotation = axis_angle_to_matrix(codedict["pose"][:, 0:3])
    euler = matrix_to_euler_angles(rotation, convention="XYZ")
    roll_rad = euler[:, 2].mean().item()
    return roll_rad * 180.0 / math.pi


def _mean(values):
    """Arithmetic mean of a non-empty list of floats."""
    return float(sum(values) / len(values))


def compute_video_head_condition(
    deca,
    video_path: str,
    device: str = "cuda",
    iscrop: bool = True,
    detector: str = "fan",
    sample_step: int = 10,
    distance_k: float = 1.0,
):
    """Estimate time-averaged head distance, position and roll for one video.

    Every frame yielded by ``datasets.TestData`` is run through DECA and three
    per-frame quantities are collected:

    * head distance: ``d = x / (x + k)`` with ``x = 1 / (cam[:, 0] + eps)``
      (0.5 when that is undefined, see ``_normalized_distance``);
    * head center ``(x, y, z)``: the mean of ``verts`` over the vertex axis;
    * roll in degrees: Z Euler angle of the global rotation in
      ``pose[:, 0:3]`` (0.0 if DECA returns no pose).

    The per-frame values are then averaged over all processed frames.

    Args:
        deca: DECA model exposing ``encode(image)`` and ``decode(codedict)``.
        video_path: Path of the video to analyse.
        device: Torch device the frame tensors are moved to.
        iscrop: Forwarded to ``datasets.TestData`` (face cropping switch).
        detector: Forwarded to ``datasets.TestData`` as ``face_detector``.
        sample_step: Forwarded to ``datasets.TestData`` (frame sampling step).
        distance_k: The ``k`` constant of the distance normalisation.

    Returns:
        A dict with float entries "distance", "x", "y", "z" and "roll", or
        ``None`` if the file does not exist, the dataset cannot be created,
        or no frame was processed.
    """
    if not os.path.exists(video_path):
        logging.warning(f"Video not found: {video_path}")
        return None

    try:
        testdata = datasets.TestData(
            video_path,
            iscrop=iscrop,
            face_detector=detector,
            sample_step=sample_step,
        )
    except Exception as e:
        # Best effort: one unreadable video must not abort a whole batch run.
        logging.exception(f"Failed to create TestData for {video_path}: {e}")
        return None

    k = float(distance_k)

    distances = []
    xs, ys, zs = [], [], []
    rolls = []

    # Index-based access on purpose: only __len__/__getitem__ of the dataset
    # are relied upon, exactly as before.
    for idx in range(len(testdata)):
        image = testdata[idx]["image"].to(device)[None, ...]

        # Pure inference: keep encode, decode and the pose conversion all
        # inside no_grad so no autograd graph is built even when the caller
        # has not disabled gradients globally.
        with torch.no_grad():
            codedict = deca.encode(image)
            opdict, _ = deca.decode(codedict)

            scale = codedict["cam"][:, 0].mean().item()
            # Mean over the vertex axis, first batch element -> [x, y, z].
            center = opdict["verts"].mean(dim=1)[0].tolist()
            roll_deg = _roll_degrees(codedict)

        distances.append(_normalized_distance(scale, k))
        xs.append(float(center[0]))
        ys.append(float(center[1]))
        zs.append(float(center[2]))
        rolls.append(roll_deg)

    if not distances:
        logging.warning(f"No frames processed for {video_path}")
        return None

    return {
        "distance": _mean(distances),
        "x": _mean(xs),
        "y": _mean(ys),
        "z": _mean(zs),
        "roll": _mean(rolls),
    }
|
|
|
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``input_json``,
        ``output_json``, ``device``, ``rasterizer_type``, ``detector``,
        ``sample_step``, ``iscrop``, ``root`` and ``distance_k``.
    """
    cli = argparse.ArgumentParser(
        description="Extract head_condition (distance + x,y,z + roll) from videos using DECA."
    )

    # Prefer the first GPU when CUDA is usable, otherwise fall back to CPU.
    fallback_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        (
            "--input_json",
            dict(
                type=str,
                required=True,
                help="输入 JSON 文件路径(列表,每个元素里有 file_path 字段)。",
            ),
        ),
        (
            "--output_json",
            dict(
                type=str,
                required=True,
                help="输出 JSON 文件路径(会写入 head_condition 字段)。",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                default=fallback_device,
                help="运行设备,如 cuda:0 或 cpu。",
            ),
        ),
        (
            "--rasterizer_type",
            dict(
                type=str,
                default="pytorch3d",
                choices=["standard", "pytorch3d"],
                help="DECA 使用的 rasterizer 类型。",
            ),
        ),
        (
            "--detector",
            dict(
                type=str,
                default="fan",
                help="DECA TestData 使用的人脸检测器(默认 fan)。",
            ),
        ),
        (
            "--sample_step",
            dict(
                type=int,
                default=10,
                help="视频帧采样间隔(每隔 sample_step 帧取一帧)。",
            ),
        ),
        (
            "--iscrop",
            dict(
                action="store_true",
                help="是否按照 DECA 的方式先做人脸裁剪(默认 False,不加此参数即为 False)。",
            ),
        ),
        (
            "--root",
            dict(
                type=str,
                default="",
                help="(可选)视频文件的根目录。如果提供且 file_path 为相对路径,则用 os.path.join(root, file_path) 读取视频。",
            ),
        ),
        (
            "--distance_k",
            dict(
                type=float,
                default=1.0,
                help="归一化函数 d = x / (x + k) 中的 k 值(x = 1/scale)。",
            ),
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    # Cropping stays off unless --iscrop is given explicitly.
    cli.set_defaults(iscrop=False)

    parsed = cli.parse_args()
    return parsed
|
|
|
|
|
|
|
|
# Keys of the "head_condition" record, in the order they are written out.
_HEAD_CONDITION_KEYS = ("distance", "x", "y", "z", "roll")


def _default_head_condition():
    """Return a fresh placeholder record for items that cannot be processed."""
    return {"distance": 0.5, "x": 0.0, "y": 0.0, "z": 0.0, "roll": 0.0}


def main():
    """Annotate every item of the input JSON list with a "head_condition".

    Steps:
      1. Parse the command line and load ``--input_json`` (must be a list).
      2. Make sure the directory of ``--output_json`` exists *before* the
         expensive processing starts, so a bad output location cannot throw
         away a finished run.
      3. Build DECA (texture disabled) on the requested device.
      4. For each item, resolve ``file_path`` (joined with ``--root`` when it
         is relative and a root was given), compute its head condition and
         store it under "head_condition". Items without ``file_path`` or
         whose video cannot be processed receive the default placeholder.
      5. Write the annotated list to ``--output_json``.

    Raises:
        ValueError: if the input JSON is not a list.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
    )

    args = parse_args()
    # Inference only: disable autograd for the whole process.
    torch.set_grad_enabled(False)

    with open(args.input_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError("input_json 必须是一个列表,每个元素是一个样本字典。")

    # Fail fast on an unusable output location instead of after all videos
    # have been processed.
    output_dir = os.path.dirname(os.path.abspath(args.output_json))
    os.makedirs(output_dir, exist_ok=True)

    device = args.device
    logging.info(f"Using device: {device}")
    logging.info(f"Using rasterizer_type: {args.rasterizer_type}")
    logging.info(f"Using distance_k: {args.distance_k}")

    deca_cfg.rasterizer_type = args.rasterizer_type
    deca_cfg.model.use_tex = False
    deca = DECA(config=deca_cfg, device=device)

    new_data = []
    fallback_count = 0

    for item in tqdm(data, desc="Videos"):
        # Shallow copy so the loaded input structure is left untouched.
        item = dict(item)
        file_path = item.get("file_path", None)

        if file_path is None:
            logging.warning("Item has no file_path field, fill head_condition with defaults.")
            item["head_condition"] = _default_head_condition()
            fallback_count += 1
            new_data.append(item)
            continue

        # Only relative paths are re-rooted; absolute paths are used as-is.
        if args.root and not os.path.isabs(file_path):
            video_path = os.path.join(args.root, file_path)
        else:
            video_path = file_path

        cond = compute_video_head_condition(
            deca=deca,
            video_path=video_path,
            device=device,
            iscrop=args.iscrop,
            detector=args.detector,
            sample_step=args.sample_step,
            distance_k=args.distance_k,
        )

        if cond is None:
            cond = _default_head_condition()
            fallback_count += 1

        # Coerce to plain floats so the record is always JSON-serialisable.
        item["head_condition"] = {key: float(cond[key]) for key in _HEAD_CONDITION_KEYS}
        new_data.append(item)

    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(new_data, f, ensure_ascii=False, indent=2)

    if fallback_count:
        # Placeholders are indistinguishable from real values in the output,
        # so make their number visible in the log.
        logging.warning(f"{fallback_count} of {len(new_data)} items were filled with default head_condition.")

    logging.info(f"Done. Saved to {args.output_json}")
|
|
|
|
|
|
|
|
# Script entry point: run the batch extraction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|