Spaces:
Running on Zero
Update inference.py
Browse files- inference.py +9 -49
inference.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import argparse
|
| 2 |
import os
|
| 3 |
-
import shutil
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Dict, Optional, List, Tuple
|
| 6 |
from collections import defaultdict
|
|
@@ -36,10 +35,6 @@ session = new_session(providers=providers)
|
|
| 36 |
weight_dtype = torch.float16
|
| 37 |
|
| 38 |
|
| 39 |
-
# ============================================================
|
| 40 |
-
# Config
|
| 41 |
-
# ============================================================
|
| 42 |
-
|
| 43 |
@dataclass
|
| 44 |
class TestConfig:
|
| 45 |
pretrained_model_name_or_path: str
|
|
@@ -63,16 +58,11 @@ class TestConfig:
|
|
| 63 |
with_smpl: Optional[bool]
|
| 64 |
recon_opt: Dict
|
| 65 |
|
| 66 |
-
# New two-stage fields
|
| 67 |
run_mode: str = "full" # full | generate | reconstruct
|
| 68 |
multiview_tmp_dir: str = "./multiview"
|
| 69 |
prefer_edited_views: bool = True
|
| 70 |
|
| 71 |
|
| 72 |
-
# ============================================================
|
| 73 |
-
# Image helpers
|
| 74 |
-
# ============================================================
|
| 75 |
-
|
| 76 |
def convert_to_numpy(tensor):
|
| 77 |
return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
| 78 |
|
|
@@ -92,10 +82,6 @@ def save_image_tensor(tensor, fp):
|
|
| 92 |
return ndarr
|
| 93 |
|
| 94 |
|
| 95 |
-
# ============================================================
|
| 96 |
-
# Multiview storage helpers
|
| 97 |
-
# ============================================================
|
| 98 |
-
|
| 99 |
def ensure_dir(path: Path):
|
| 100 |
path.mkdir(parents=True, exist_ok=True)
|
| 101 |
|
|
@@ -108,24 +94,20 @@ def save_multiview_scene(multiview_root: str, scene: str, colors: List[Image.Ima
|
|
| 108 |
ensure_dir(raw_dir)
|
| 109 |
ensure_dir(edit_dir)
|
| 110 |
|
| 111 |
-
# Clean previous files to avoid stale leftovers
|
| 112 |
for folder in (raw_dir, edit_dir):
|
| 113 |
for p in folder.glob("*"):
|
| 114 |
if p.is_file():
|
| 115 |
p.unlink()
|
| 116 |
|
| 117 |
for idx, img in enumerate(colors):
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
img.save(raw_color)
|
| 121 |
-
img.save(edit_color)
|
| 122 |
|
| 123 |
for idx, img in enumerate(normals):
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
img.save(raw_normal)
|
| 127 |
-
img.save(edit_normal)
|
| 128 |
|
|
|
|
| 129 |
meta = {
|
| 130 |
"scene": scene,
|
| 131 |
"num_colors": len(colors),
|
|
@@ -133,7 +115,6 @@ def save_multiview_scene(multiview_root: str, scene: str, colors: List[Image.Ima
|
|
| 133 |
"source": "PSHuman two-stage inference",
|
| 134 |
}
|
| 135 |
with open(scene_dir / "meta.json", "w", encoding="utf-8") as f:
|
| 136 |
-
import json
|
| 137 |
json.dump(meta, f, indent=2)
|
| 138 |
|
| 139 |
|
|
@@ -159,10 +140,6 @@ def load_multiview_scene(multiview_root: str, scene: str, prefer_edit=True) -> T
|
|
| 159 |
return colors, normals
|
| 160 |
|
| 161 |
|
| 162 |
-
# ============================================================
|
| 163 |
-
# Pipeline helpers
|
| 164 |
-
# ============================================================
|
| 165 |
-
|
| 166 |
def load_pshuman_pipeline(cfg):
|
| 167 |
pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
|
| 168 |
cfg.pretrained_model_name_or_path,
|
|
@@ -174,25 +151,17 @@ def load_pshuman_pipeline(cfg):
|
|
| 174 |
return pipeline
|
| 175 |
|
| 176 |
|
| 177 |
-
def extract_scene_views_for_case(
|
| 178 |
-
batch,
|
| 179 |
-
out,
|
| 180 |
-
imgs_in,
|
| 181 |
-
i: int,
|
| 182 |
-
num_views: int,
|
| 183 |
-
):
|
| 184 |
normals_pred = out[: out.shape[0] // 2]
|
| 185 |
images_pred = out[out.shape[0] // 2:]
|
| 186 |
|
| 187 |
scene = batch['filename'][i].split('.')[0]
|
| 188 |
-
|
| 189 |
normals, colors = [], []
|
| 190 |
|
| 191 |
for j in range(num_views):
|
| 192 |
idx = i * num_views + j
|
| 193 |
normal = normals_pred[idx]
|
| 194 |
|
| 195 |
-
# Fix from original code: use scene-local first input image
|
| 196 |
if j == 0:
|
| 197 |
color = imgs_in[i * num_views].to(out.device)
|
| 198 |
else:
|
|
@@ -214,35 +183,29 @@ def extract_scene_views_for_case(
|
|
| 214 |
|
| 215 |
normals.append(normal)
|
| 216 |
|
| 217 |
-
# Preserve original PSHuman behavior
|
| 218 |
if len(normals) >= 2:
|
| 219 |
normals[0][:, :256, 256:512] = normals[-1]
|
| 220 |
|
| 221 |
-
# Original code keeps first 6 views only
|
| 222 |
colors_pil = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
|
| 223 |
normals_pil = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
|
| 224 |
|
| 225 |
return scene, colors_pil, normals_pil
|
| 226 |
|
| 227 |
|
| 228 |
-
# ============================================================
|
| 229 |
-
# Main inference logic
|
| 230 |
-
# ============================================================
|
| 231 |
-
|
| 232 |
def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
|
| 233 |
-
pipeline
|
|
|
|
| 234 |
|
| 235 |
if cfg.seed is None:
|
| 236 |
generator = None
|
| 237 |
else:
|
| 238 |
-
generator = torch.Generator(device=
|
| 239 |
|
| 240 |
images_cond, pred_cat = [], defaultdict(list)
|
| 241 |
|
| 242 |
for case_id, batch in tqdm(enumerate(dataloader)):
|
| 243 |
images_cond.append(batch['imgs_in'][:, 0])
|
| 244 |
|
| 245 |
-
# Reconstruct-only path: skip diffusion, load saved views instead
|
| 246 |
if cfg.run_mode == "reconstruct":
|
| 247 |
scene = batch['filename'][0].split('.')[0]
|
| 248 |
colors, normals = load_multiview_scene(
|
|
@@ -315,7 +278,6 @@ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save
|
|
| 315 |
vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
|
| 316 |
save_image_tensor(vis_, out_filename)
|
| 317 |
|
| 318 |
-
# concat mode is only for legacy visualization
|
| 319 |
continue
|
| 320 |
|
| 321 |
elif cfg.save_mode == 'rgb':
|
|
@@ -332,7 +294,6 @@ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save
|
|
| 332 |
save_multiview_scene(cfg.multiview_tmp_dir, scene, colors, normals)
|
| 333 |
continue
|
| 334 |
|
| 335 |
-
# full mode: original one-pass behavior
|
| 336 |
pose = econdata.__getitem__(case_id)
|
| 337 |
carving.optimize_case(scene, pose, colors, normals)
|
| 338 |
torch.cuda.empty_cache()
|
|
@@ -342,7 +303,6 @@ def main(cfg: TestConfig):
|
|
| 342 |
if cfg.seed is not None:
|
| 343 |
set_seed(cfg.seed)
|
| 344 |
|
| 345 |
-
# Reconstruct mode does not need the diffusion pipeline at all
|
| 346 |
pipeline = None if cfg.run_mode == "reconstruct" else load_pshuman_pipeline(cfg)
|
| 347 |
|
| 348 |
if cfg.with_smpl:
|
|
|
|
| 1 |
import argparse
|
| 2 |
import os
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Dict, Optional, List, Tuple
|
| 5 |
from collections import defaultdict
|
|
|
|
| 35 |
weight_dtype = torch.float16
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
@dataclass
|
| 39 |
class TestConfig:
|
| 40 |
pretrained_model_name_or_path: str
|
|
|
|
| 58 |
with_smpl: Optional[bool]
|
| 59 |
recon_opt: Dict
|
| 60 |
|
|
|
|
| 61 |
run_mode: str = "full" # full | generate | reconstruct
|
| 62 |
multiview_tmp_dir: str = "./multiview"
|
| 63 |
prefer_edited_views: bool = True
|
| 64 |
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def convert_to_numpy(tensor):
    """Convert a CHW float image tensor (values assumed in [0, 1] — clamped anyway) to an HWC uint8 numpy array."""
    # mul() allocates a fresh tensor, so the in-place add_/clamp_ below never mutate the caller's data.
    scaled = tensor.mul(255).add_(0.5).clamp_(0, 255)
    return scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
| 68 |
|
|
|
|
| 82 |
return ndarr
|
| 83 |
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def ensure_dir(path: Path) -> None:
    """Create directory *path*, including any missing parents; a no-op when it already exists."""
    path.mkdir(exist_ok=True, parents=True)
|
| 87 |
|
|
|
|
| 94 |
ensure_dir(raw_dir)
|
| 95 |
ensure_dir(edit_dir)
|
| 96 |
|
|
|
|
| 97 |
for folder in (raw_dir, edit_dir):
|
| 98 |
for p in folder.glob("*"):
|
| 99 |
if p.is_file():
|
| 100 |
p.unlink()
|
| 101 |
|
| 102 |
for idx, img in enumerate(colors):
|
| 103 |
+
img.save(raw_dir / f"color_{idx:02d}.png")
|
| 104 |
+
img.save(edit_dir / f"color_{idx:02d}.png")
|
|
|
|
|
|
|
| 105 |
|
| 106 |
for idx, img in enumerate(normals):
|
| 107 |
+
img.save(raw_dir / f"normal_{idx:02d}.png")
|
| 108 |
+
img.save(edit_dir / f"normal_{idx:02d}.png")
|
|
|
|
|
|
|
| 109 |
|
| 110 |
+
import json
|
| 111 |
meta = {
|
| 112 |
"scene": scene,
|
| 113 |
"num_colors": len(colors),
|
|
|
|
| 115 |
"source": "PSHuman two-stage inference",
|
| 116 |
}
|
| 117 |
with open(scene_dir / "meta.json", "w", encoding="utf-8") as f:
|
|
|
|
| 118 |
json.dump(meta, f, indent=2)
|
| 119 |
|
| 120 |
|
|
|
|
| 140 |
return colors, normals
|
| 141 |
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
def load_pshuman_pipeline(cfg):
|
| 144 |
pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
|
| 145 |
cfg.pretrained_model_name_or_path,
|
|
|
|
| 151 |
return pipeline
|
| 152 |
|
| 153 |
|
| 154 |
+
def extract_scene_views_for_case(batch, out, imgs_in, i: int, num_views: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
normals_pred = out[: out.shape[0] // 2]
|
| 156 |
images_pred = out[out.shape[0] // 2:]
|
| 157 |
|
| 158 |
scene = batch['filename'][i].split('.')[0]
|
|
|
|
| 159 |
normals, colors = [], []
|
| 160 |
|
| 161 |
for j in range(num_views):
|
| 162 |
idx = i * num_views + j
|
| 163 |
normal = normals_pred[idx]
|
| 164 |
|
|
|
|
| 165 |
if j == 0:
|
| 166 |
color = imgs_in[i * num_views].to(out.device)
|
| 167 |
else:
|
|
|
|
| 183 |
|
| 184 |
normals.append(normal)
|
| 185 |
|
|
|
|
| 186 |
if len(normals) >= 2:
|
| 187 |
normals[0][:, :256, 256:512] = normals[-1]
|
| 188 |
|
|
|
|
| 189 |
colors_pil = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
|
| 190 |
normals_pil = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
|
| 191 |
|
| 192 |
return scene, colors_pil, normals_pil
|
| 193 |
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
|
| 196 |
+
if pipeline is not None:
|
| 197 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 198 |
|
| 199 |
if cfg.seed is None:
|
| 200 |
generator = None
|
| 201 |
else:
|
| 202 |
+
generator = torch.Generator(device='cuda' if torch.cuda.is_available() else 'cpu').manual_seed(cfg.seed)
|
| 203 |
|
| 204 |
images_cond, pred_cat = [], defaultdict(list)
|
| 205 |
|
| 206 |
for case_id, batch in tqdm(enumerate(dataloader)):
|
| 207 |
images_cond.append(batch['imgs_in'][:, 0])
|
| 208 |
|
|
|
|
| 209 |
if cfg.run_mode == "reconstruct":
|
| 210 |
scene = batch['filename'][0].split('.')[0]
|
| 211 |
colors, normals = load_multiview_scene(
|
|
|
|
| 278 |
vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
|
| 279 |
save_image_tensor(vis_, out_filename)
|
| 280 |
|
|
|
|
| 281 |
continue
|
| 282 |
|
| 283 |
elif cfg.save_mode == 'rgb':
|
|
|
|
| 294 |
save_multiview_scene(cfg.multiview_tmp_dir, scene, colors, normals)
|
| 295 |
continue
|
| 296 |
|
|
|
|
| 297 |
pose = econdata.__getitem__(case_id)
|
| 298 |
carving.optimize_case(scene, pose, colors, normals)
|
| 299 |
torch.cuda.empty_cache()
|
|
|
|
| 303 |
if cfg.seed is not None:
|
| 304 |
set_seed(cfg.seed)
|
| 305 |
|
|
|
|
| 306 |
pipeline = None if cfg.run_mode == "reconstruct" else load_pshuman_pipeline(cfg)
|
| 307 |
|
| 308 |
if cfg.with_smpl:
|