import time

import torch
import tqdm

from dust3r.utils.device import to_cpu, collate_with_cat
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf
from dust3r.utils.misc import invalid_to_nans


def _interleave_imgs(img1, img2):
    # Interleave two view dicts along the batch dimension: tensors become
    # (a0, b0, a1, b1, ...) and lists are zipped the same way.
    res = {}
    for key, value1 in img1.items():
        value2 = img2[key]
        if isinstance(value1, torch.Tensor):
            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
        else:
            value = [x for pair in zip(value1, value2) for x in pair]
        res[key] = value
    return res


def make_batch_symmetric(batch):
    # Duplicate every pair in both directions, so each image pair (a, b)
    # is also processed as (b, a); the effective batch size doubles.
    view1, view2 = batch
    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
    return view1, view2
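

# A minimal shape sketch of the symmetrization above (the dummy tensors are
# hypothetical, just to illustrate the interleaving):
#
#   v1 = {'img': torch.zeros(2, 3, 8, 8)}   # batch of 2
#   v2 = {'img': torch.ones(2, 3, 8, 8)}
#   s1, s2 = make_batch_symmetric((v1, v2))
#   # s1['img'].shape == s2['img'].shape == (4, 3, 8, 8)
#   # s1 rows are (v1[0], v2[0], v1[1], v2[1]); s2 is the mirror order.

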
def loss_of_one_batch_mv(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None, log=True):
    # Multi-view forward pass: the first view is the reference, the rest are
    # source views.
    views = batch
    view1, view2s = views[0], views[1:]
    # Move every tensor the model or criterion may need onto the target device.
    for view in batch:
        for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split():
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        t = time.time()
        pred1, pred2s = model(view1, view2s)
        nv = len(pred2s) + 1
        if log:
            # .item() forces a CUDA synchronization so the timing is accurate.
            sync = pred1['pts3d'].mean().item()
            print('pure inference time (only predictions of pcd and 3DGS parameters, '
                  'not including cam pose estimations)', time.time() - t, 'nv', nv)

        # Loss terms are computed in fp32 even when AMP is enabled.
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2s, pred1, pred2s, log=log) if criterion is not None else None

    result = dict(view1=view1, view2s=view2s, pred1=pred1, pred2s=pred2s, loss=loss)
    return result[ret] if ret else result
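

# Expected input for the multi-view path (a sketch; the exact set of keys
# depends on the dataloader): `batch` is a list of collated view dicts,
#
#   batch = [view_ref, view_src1, view_src2, ...]
#   # each view is a dict with entries such as 'img' (B, 3, H, W),
#   # 'pts3d', 'valid_mask', 'camera_pose', 'camera_intrinsics', ...
#
# The returned dict holds `pred1` for the reference view and the list
# `pred2s` with one prediction per source view.

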
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None, log=True):
    # Multi-view criteria are dispatched to the dedicated multi-view path.
    if criterion is not None and criterion.mv:
        return loss_of_one_batch_mv(batch, model, criterion, device, False, use_amp, ret, log)
    view1, view2 = batch
    for view in batch:
        for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split():
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # Loss terms are computed in fp32 even when AMP is enabled.
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, pred1, pred2, log) if criterion is not None else None

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result


@torch.no_grad()
def inference_mv(batch, model, device, verbose=True):
    if verbose:
        print(f'>> Inference with model on {len(batch)} images')

    # Single forward pass over all views; wrapped in a list so the result
    # can be collated the same way as the pairwise path below.
    result = []
    res = loss_of_one_batch_mv(batch, model, None, device, log=True)
    result.append(to_cpu(res))

    result = collate_with_cat(result, lists=False)
    return result


@torch.no_grad()
def inference(pairs, model, device, batch_size=8, verbose=True):
    if verbose:
        print(f'>> Inference with model on {len(pairs)} image pairs')
    result = []

    # If the pairs do not all share the same image resolution, they cannot be
    # collated into batches, so fall back to one pair at a time.
    multiple_shapes = not check_if_same_size(pairs)
    if multiple_shapes:
        batch_size = 1

    for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
        res = loss_of_one_batch(collate_with_cat(pairs[i:i + batch_size]), model, None, device)
        result.append(to_cpu(res))

    result = collate_with_cat(result, lists=multiple_shapes)
    return result
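

# Example usage (a sketch; assumes the standard dust3r helpers `load_images`
# and `make_pairs`, and a hypothetical checkpoint path):
#
#   from dust3r.model import AsymmetricCroCo3DStereo
#   from dust3r.utils.image import load_images
#   from dust3r.image_pairs import make_pairs
#
#   model = AsymmetricCroCo3DStereo.from_pretrained('checkpoints/model.pth').to('cuda')
#   imgs = load_images(['scene/0.jpg', 'scene/1.jpg'], size=512)
#   pairs = make_pairs(imgs, scene_graph='complete', symmetrize=True)
#   out = inference(pairs, model, 'cuda', batch_size=1)
#   # out['pred1']['pts3d'] holds 3D points in each reference view's frame.

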
def check_if_same_size(pairs):
    # True if all first images share one (H, W) and all second images share one (H, W).
    shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs]
    shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs]
    return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2)


def get_pred_pts3d(gt, pred, use_pose=False):
    if 'depth' in pred and 'pseudo_focal' in pred:
        # Reconstruct 3D points from the predicted depthmap and focal;
        # the principal point comes from GT intrinsics when available.
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif 'pts3d' in pred:
        # 3D points are already predicted in this view's camera frame.
        pts3d = pred['pts3d']

    elif 'pts3d_in_other_view' in pred:
        # 3D points are already expressed in the other view's frame;
        # no further pose transform is needed.
        assert use_pose is True
        return pred['pts3d_in_other_view']

    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d


def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # Replace invalid points with NaNs so the reductions below can ignore them.
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None

    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None

    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith('avg'):
        # Closed-form least-squares scaling.
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith('median'):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith('weiszfeld'):
        # Initialize with the least-squares solution.
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # Iteratively reweight by inverse residual distance (Weiszfeld scheme).
        for _ in range(10):
            # Distance of each prediction to the currently scaled ground truth.
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            # Inverse-distance weights, clipped to avoid division by zero.
            w = dis.clip_(min=1e-8).reciprocal()
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f'bad {fit_mode=}')

    if fit_mode.endswith('stop_grad'):
        scaling = scaling.detach()

    scaling = scaling.clip(min=1e-3)
    return scaling
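

# Note on the 'avg' mode above: it is the closed-form solution of
#   min_s  sum_i || pr_i - s * gt_i ||^2   =>   s = sum_i <pr_i, gt_i> / sum_i |gt_i|^2,
# computed per batch element with NaN-aware means so invalid points drop out.
# The 'weiszfeld' mode starts from this solution and reweights each point by
# the reciprocal of its residual distance, which is more robust to outliers.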