painter3000 committed on
Commit
aa4e85e
·
verified ·
1 Parent(s): 66043e5

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +9 -49
inference.py CHANGED
@@ -1,6 +1,5 @@
1
  import argparse
2
  import os
3
- import shutil
4
  from pathlib import Path
5
  from typing import Dict, Optional, List, Tuple
6
  from collections import defaultdict
@@ -36,10 +35,6 @@ session = new_session(providers=providers)
36
  weight_dtype = torch.float16
37
 
38
 
39
- # ============================================================
40
- # Config
41
- # ============================================================
42
-
43
  @dataclass
44
  class TestConfig:
45
  pretrained_model_name_or_path: str
@@ -63,16 +58,11 @@ class TestConfig:
63
  with_smpl: Optional[bool]
64
  recon_opt: Dict
65
 
66
- # New two-stage fields
67
  run_mode: str = "full" # full | generate | reconstruct
68
  multiview_tmp_dir: str = "./multiview"
69
  prefer_edited_views: bool = True
70
 
71
 
72
- # ============================================================
73
- # Image helpers
74
- # ============================================================
75
-
76
def convert_to_numpy(tensor):
    """Convert a float image tensor with values in [0, 1] (C, H, W layout,
    judging by the permute) into an (H, W, C) uint8 numpy array.

    Scaling uses round-half-up via ``add_(0.5)`` before the uint8 cast.
    """
    img = tensor.mul(255).add_(0.5).clamp_(0, 255)
    img = img.permute(1, 2, 0)
    return img.to("cpu", torch.uint8).numpy()
78
 
@@ -92,10 +82,6 @@ def save_image_tensor(tensor, fp):
92
  return ndarr
93
 
94
 
95
- # ============================================================
96
- # Multiview storage helpers
97
- # ============================================================
98
-
99
def ensure_dir(path: Path):
    """Make sure *path* exists as a directory, creating parent dirs as needed."""
    os.makedirs(path, exist_ok=True)
101
 
@@ -108,24 +94,20 @@ def save_multiview_scene(multiview_root: str, scene: str, colors: List[Image.Ima
108
  ensure_dir(raw_dir)
109
  ensure_dir(edit_dir)
110
 
111
- # Clean previous files to avoid stale leftovers
112
  for folder in (raw_dir, edit_dir):
113
  for p in folder.glob("*"):
114
  if p.is_file():
115
  p.unlink()
116
 
117
  for idx, img in enumerate(colors):
118
- raw_color = raw_dir / f"color_{idx:02d}.png"
119
- edit_color = edit_dir / f"color_{idx:02d}.png"
120
- img.save(raw_color)
121
- img.save(edit_color)
122
 
123
  for idx, img in enumerate(normals):
124
- raw_normal = raw_dir / f"normal_{idx:02d}.png"
125
- edit_normal = edit_dir / f"normal_{idx:02d}.png"
126
- img.save(raw_normal)
127
- img.save(edit_normal)
128
 
 
129
  meta = {
130
  "scene": scene,
131
  "num_colors": len(colors),
@@ -133,7 +115,6 @@ def save_multiview_scene(multiview_root: str, scene: str, colors: List[Image.Ima
133
  "source": "PSHuman two-stage inference",
134
  }
135
  with open(scene_dir / "meta.json", "w", encoding="utf-8") as f:
136
- import json
137
  json.dump(meta, f, indent=2)
138
 
139
 
@@ -159,10 +140,6 @@ def load_multiview_scene(multiview_root: str, scene: str, prefer_edit=True) -> T
159
  return colors, normals
160
 
161
 
162
- # ============================================================
163
- # Pipeline helpers
164
- # ============================================================
165
-
166
  def load_pshuman_pipeline(cfg):
167
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
168
  cfg.pretrained_model_name_or_path,
@@ -174,25 +151,17 @@ def load_pshuman_pipeline(cfg):
174
  return pipeline
175
 
176
 
177
- def extract_scene_views_for_case(
178
- batch,
179
- out,
180
- imgs_in,
181
- i: int,
182
- num_views: int,
183
- ):
184
  normals_pred = out[: out.shape[0] // 2]
185
  images_pred = out[out.shape[0] // 2:]
186
 
187
  scene = batch['filename'][i].split('.')[0]
188
-
189
  normals, colors = [], []
190
 
191
  for j in range(num_views):
192
  idx = i * num_views + j
193
  normal = normals_pred[idx]
194
 
195
- # Fix from original code: use scene-local first input image
196
  if j == 0:
197
  color = imgs_in[i * num_views].to(out.device)
198
  else:
@@ -214,35 +183,29 @@ def extract_scene_views_for_case(
214
 
215
  normals.append(normal)
216
 
217
- # Preserve original PSHuman behavior
218
  if len(normals) >= 2:
219
  normals[0][:, :256, 256:512] = normals[-1]
220
 
221
- # Original code keeps first 6 views only
222
  colors_pil = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
223
  normals_pil = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
224
 
225
  return scene, colors_pil, normals_pil
226
 
227
 
228
- # ============================================================
229
- # Main inference logic
230
- # ============================================================
231
-
232
  def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
233
- pipeline.set_progress_bar_config(disable=True)
 
234
 
235
  if cfg.seed is None:
236
  generator = None
237
  else:
238
- generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
239
 
240
  images_cond, pred_cat = [], defaultdict(list)
241
 
242
  for case_id, batch in tqdm(enumerate(dataloader)):
243
  images_cond.append(batch['imgs_in'][:, 0])
244
 
245
- # Reconstruct-only path: skip diffusion, load saved views instead
246
  if cfg.run_mode == "reconstruct":
247
  scene = batch['filename'][0].split('.')[0]
248
  colors, normals = load_multiview_scene(
@@ -315,7 +278,6 @@ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save
315
  vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
316
  save_image_tensor(vis_, out_filename)
317
 
318
- # concat mode is only for legacy visualization
319
  continue
320
 
321
  elif cfg.save_mode == 'rgb':
@@ -332,7 +294,6 @@ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save
332
  save_multiview_scene(cfg.multiview_tmp_dir, scene, colors, normals)
333
  continue
334
 
335
- # full mode: original one-pass behavior
336
  pose = econdata.__getitem__(case_id)
337
  carving.optimize_case(scene, pose, colors, normals)
338
  torch.cuda.empty_cache()
@@ -342,7 +303,6 @@ def main(cfg: TestConfig):
342
  if cfg.seed is not None:
343
  set_seed(cfg.seed)
344
 
345
- # Reconstruct mode does not need the diffusion pipeline at all
346
  pipeline = None if cfg.run_mode == "reconstruct" else load_pshuman_pipeline(cfg)
347
 
348
  if cfg.with_smpl:
 
1
  import argparse
2
  import os
 
3
  from pathlib import Path
4
  from typing import Dict, Optional, List, Tuple
5
  from collections import defaultdict
 
35
  weight_dtype = torch.float16
36
 
37
 
 
 
 
 
38
  @dataclass
39
  class TestConfig:
40
  pretrained_model_name_or_path: str
 
58
  with_smpl: Optional[bool]
59
  recon_opt: Dict
60
 
 
61
  run_mode: str = "full" # full | generate | reconstruct
62
  multiview_tmp_dir: str = "./multiview"
63
  prefer_edited_views: bool = True
64
 
65
 
 
 
 
 
66
def convert_to_numpy(tensor):
    """Turn a (C, H, W) float tensor in [0, 1] into an HWC uint8 numpy array,
    rounding half-up before the integer cast."""
    scaled = tensor.mul(255).add_(0.5).clamp_(0, 255)
    return scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
68
 
 
82
  return ndarr
83
 
84
 
 
 
 
 
85
def ensure_dir(path: Path):
    """Create the directory *path* (including parents) when it does not exist."""
    if not path.is_dir():
        path.mkdir(parents=True, exist_ok=True)
87
 
 
94
  ensure_dir(raw_dir)
95
  ensure_dir(edit_dir)
96
 
 
97
  for folder in (raw_dir, edit_dir):
98
  for p in folder.glob("*"):
99
  if p.is_file():
100
  p.unlink()
101
 
102
  for idx, img in enumerate(colors):
103
+ img.save(raw_dir / f"color_{idx:02d}.png")
104
+ img.save(edit_dir / f"color_{idx:02d}.png")
 
 
105
 
106
  for idx, img in enumerate(normals):
107
+ img.save(raw_dir / f"normal_{idx:02d}.png")
108
+ img.save(edit_dir / f"normal_{idx:02d}.png")
 
 
109
 
110
+ import json
111
  meta = {
112
  "scene": scene,
113
  "num_colors": len(colors),
 
115
  "source": "PSHuman two-stage inference",
116
  }
117
  with open(scene_dir / "meta.json", "w", encoding="utf-8") as f:
 
118
  json.dump(meta, f, indent=2)
119
 
120
 
 
140
  return colors, normals
141
 
142
 
 
 
 
 
143
  def load_pshuman_pipeline(cfg):
144
  pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
145
  cfg.pretrained_model_name_or_path,
 
151
  return pipeline
152
 
153
 
154
+ def extract_scene_views_for_case(batch, out, imgs_in, i: int, num_views: int):
 
 
 
 
 
 
155
  normals_pred = out[: out.shape[0] // 2]
156
  images_pred = out[out.shape[0] // 2:]
157
 
158
  scene = batch['filename'][i].split('.')[0]
 
159
  normals, colors = [], []
160
 
161
  for j in range(num_views):
162
  idx = i * num_views + j
163
  normal = normals_pred[idx]
164
 
 
165
  if j == 0:
166
  color = imgs_in[i * num_views].to(out.device)
167
  else:
 
183
 
184
  normals.append(normal)
185
 
 
186
  if len(normals) >= 2:
187
  normals[0][:, :256, 256:512] = normals[-1]
188
 
 
189
  colors_pil = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
190
  normals_pil = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
191
 
192
  return scene, colors_pil, normals_pil
193
 
194
 
 
 
 
 
195
  def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
196
+ if pipeline is not None:
197
+ pipeline.set_progress_bar_config(disable=True)
198
 
199
  if cfg.seed is None:
200
  generator = None
201
  else:
202
+ generator = torch.Generator(device='cuda' if torch.cuda.is_available() else 'cpu').manual_seed(cfg.seed)
203
 
204
  images_cond, pred_cat = [], defaultdict(list)
205
 
206
  for case_id, batch in tqdm(enumerate(dataloader)):
207
  images_cond.append(batch['imgs_in'][:, 0])
208
 
 
209
  if cfg.run_mode == "reconstruct":
210
  scene = batch['filename'][0].split('.')[0]
211
  colors, normals = load_multiview_scene(
 
278
  vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
279
  save_image_tensor(vis_, out_filename)
280
 
 
281
  continue
282
 
283
  elif cfg.save_mode == 'rgb':
 
294
  save_multiview_scene(cfg.multiview_tmp_dir, scene, colors, normals)
295
  continue
296
 
 
297
  pose = econdata.__getitem__(case_id)
298
  carving.optimize_case(scene, pose, colors, normals)
299
  torch.cuda.empty_cache()
 
303
  if cfg.seed is not None:
304
  set_seed(cfg.seed)
305
 
 
306
  pipeline = None if cfg.run_mode == "reconstruct" else load_pshuman_pipeline(cfg)
307
 
308
  if cfg.with_smpl: