|
|
from copy import copy, deepcopy |
|
|
import cv2 |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
|
|
from pytorch3d.ops import knn_points |
|
|
from dust3r.utils.geometry import xy_grid |
|
|
|
|
|
from dust3r.inference import get_pred_pts3d, find_opt_scaling |
|
|
from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud, normalize_pointclouds |
|
|
from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale, get_joint_pointcloud_depths, get_joint_pointcloud_center_scales |
|
|
|
|
|
from torch.utils.data import default_collate |
|
|
|
|
|
import random |
|
|
from pytorch3d.transforms import so3_relative_angle |
|
|
|
|
|
def batched_all_pairs(B, N):
    """Indices (i1, i2) of all unordered view pairs within each of B batch
    elements, expressed in the flattened (B * N) view dimension."""
|
|
|
|
|
i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1) |
|
|
i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]] |
|
|
|
|
|
return i1, i2 |
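
def _demo_batched_all_pairs():
    # Hypothetical usage sketch (not part of the original file): for B=2
    # batches of N=3 views, the pair indices address the flattened (B*N)
    # view dimension, so the second batch's pairs start at offset 3.
    i1, i2 = batched_all_pairs(2, 3)
    assert i1.tolist() == [0, 0, 1, 3, 3, 4]
    assert i2.tolist() == [1, 2, 2, 4, 5, 5]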
|
|
|
|
|
def closed_form_inverse(se3): |
|
|
""" |
|
|
Computes the inverse of each 4x4 SE3 matrix in the batch. |
|
|
|
|
|
Args: |
|
|
- se3 (Tensor): Nx4x4 tensor of SE3 matrices. |
|
|
|
|
|
Returns: |
|
|
- Tensor: Nx4x4 tensor of inverted SE3 matrices. |
|
|
""" |
|
|
R = se3[:, :3, :3] |
|
|
T = se3[:, 3:, :3] |
|
|
|
|
|
|
|
|
R_transposed = R.transpose(1, 2) |
|
|
|
|
|
|
|
|
left_bottom = -T.bmm(R_transposed) |
|
|
left_combined = torch.cat((R_transposed, left_bottom), dim=1) |
|
|
|
|
|
|
|
|
right_col = se3[:, :, 3:].detach().clone() |
|
|
inverted_matrix = torch.cat((left_combined, right_col), dim=-1) |
|
|
|
|
|
return inverted_matrix |
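
def _demo_closed_form_inverse():
    # Hypothetical sanity check (assumption: row-vector convention with the
    # translation stored in the bottom row): M @ closed_form_inverse(M) == I.
    c, s = np.cos(0.3), np.sin(0.3)
    M = torch.eye(4).unsqueeze(0)
    M[0, :3, :3] = torch.tensor([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
    M[0, 3, :3] = torch.tensor([1., 2., 3.])  # translation in the bottom row
    assert torch.allclose(M.bmm(closed_form_inverse(M)), torch.eye(4)[None], atol=1e-6)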
|
|
|
|
|
def rotation_angle(rot_gt, rot_pred, batch_size=None): |
|
|
|
|
|
try: |
|
|
rel_angle_cos = so3_relative_angle(rot_gt, rot_pred, eps=1e-1) |
|
|
    except Exception:  # fall back to a direct computation if so3_relative_angle rejects the input
|
|
R_diff = rot_gt @ rot_pred.transpose(-1, -2) |
|
|
trace_R_diff = R_diff[:,0,0] + R_diff[:,1,1] + R_diff[:,2,2] |
|
|
cos = (trace_R_diff - 1) / 2 |
|
|
cos = torch.clamp(cos, -1 + 1e-3, 1 - 1e-3) |
|
|
rel_angle_cos = torch.acos(cos) |
|
|
rel_rangle_deg = rel_angle_cos * 180 / np.pi |
|
|
|
|
|
if batch_size is not None: |
|
|
rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1) |
|
|
|
|
|
return rel_rangle_deg |
|
|
|
|
|
def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6): |
|
|
"""Normalize the translation vectors and compute the angle between them.""" |
|
|
t_norm = torch.norm(t, dim=1, keepdim=True) |
|
|
t = t / (t_norm + eps) |
|
|
|
|
|
t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True) |
|
|
t_gt = t_gt / (t_gt_norm + eps) |
|
|
|
|
|
loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps) |
|
|
err_t = torch.acos(torch.sqrt(1 - loss_t)) |
|
|
|
|
|
err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err |
|
|
return err_t |
|
|
|
|
|
def translation_angle(tvec_gt, tvec_pred, batch_size=None): |
|
|
|
|
|
rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred) |
|
|
rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi |
|
|
|
|
|
if batch_size is not None: |
|
|
rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1) |
|
|
|
|
|
return rel_tangle_deg |
|
|
|
|
|
def camera_to_rel_deg(pred_se3, gt_se3, device, batch_size): |
|
|
""" |
|
|
Calculate relative rotation and translation angles between predicted and ground truth cameras. |
|
|
|
|
|
Args: |
|
|
- pred_cameras: Predicted camera. |
|
|
- gt_cameras: Ground truth camera. |
|
|
- accelerator: The device for moving tensors to GPU or others. |
|
|
- batch_size: Number of data samples in one batch. |
|
|
|
|
|
Returns: |
|
|
- rel_rotation_angle_deg, rel_translation_angle_deg: Relative rotation and translation angles in degrees. |
|
|
""" |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pair_idx_i1, pair_idx_i2 = batched_all_pairs(batch_size, gt_se3.shape[0] // batch_size) |
|
|
pair_idx_i1 = pair_idx_i1.to(device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relative_pose_gt = torch.linalg.inv(gt_se3[pair_idx_i1]).bmm(gt_se3[pair_idx_i2]) |
|
|
relative_pose_pred = torch.linalg.inv(pred_se3[pair_idx_i1]).bmm(pred_se3[pair_idx_i2]) |
|
|
|
|
|
|
|
|
|
|
|
rel_rangle_deg = rotation_angle(relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]) |
|
|
rel_tangle_deg = translation_angle(relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]) |
|
|
|
|
|
return rel_rangle_deg, rel_tangle_deg |
|
|
|
|
|
def estimate_focal_knowing_depth(pts3d, valid_mask, min_focal=0., max_focal=np.inf): |
|
|
""" Reprojection method, for when the absolute depth is known: |
|
|
1) estimate the camera focal using a robust estimator |
|
|
2) reproject points onto true rays, minimizing a certain error |
|
|
""" |
|
|
B, H, W, THREE = pts3d.shape |
|
|
assert THREE == 3 |
|
|
|
|
|
|
|
|
pp = torch.tensor([[W/2, H/2]], dtype=torch.float32, device=pts3d.device) |
|
|
pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) |
|
|
pts3d = pts3d.flatten(1, 2) |
|
|
valid_mask = valid_mask.flatten(1, 2) |
|
|
pixels = pixels[valid_mask].unsqueeze(0) |
|
|
pts3d = pts3d[valid_mask].unsqueeze(0) |
|
|
|
|
|
|
|
|
|
|
|
xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) |
|
|
|
|
|
dot_xy_px = (xy_over_z * pixels).sum(dim=-1) |
|
|
dot_xy_xy = xy_over_z.square().sum(dim=-1) |
|
|
|
|
|
focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) |
|
|
|
|
|
|
|
|
    for _ in range(10):  # iteratively reweighted least-squares refinement
|
|
|
|
|
dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) |
|
|
|
|
|
w = dis.clip(min=1e-8).reciprocal() |
|
|
|
|
|
focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) |
|
|
|
|
|
focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) |
|
|
focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base) |
|
|
|
|
|
return focal |
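
def _demo_estimate_focal_knowing_depth():
    # Hypothetical sanity check (assumptions: dust3r's xy_grid accepts a
    # device argument, a fronto-parallel plane at unit depth, principal
    # point at the image center): points generated with a known focal
    # should give that focal back.
    H = W = 64
    f = 100.0
    uv = xy_grid(W, H, device='cpu').float()  # (H, W, 2) pixel coordinates
    pts3d = torch.empty(1, H, W, 3)
    pts3d[..., 0] = (uv[..., 0] - W / 2) / f
    pts3d[..., 1] = (uv[..., 1] - H / 2) / f
    pts3d[..., 2] = 1.0
    valid = torch.ones(1, H, W, dtype=torch.bool)
    focal = estimate_focal_knowing_depth(pts3d, valid)
    assert abs(focal.item() - f) < 0.1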
|
|
|
|
|
def recursive_concat_collate(batch): |
|
|
if isinstance(batch[0], torch.Tensor): |
|
|
return torch.cat(batch, dim=0) |
|
|
|
|
|
elif isinstance(batch[0], dict): |
|
|
return {key: recursive_concat_collate([d[key] for d in batch]) for key in batch[0]} |
|
|
|
|
|
elif isinstance(batch[0], list): |
|
|
return [recursive_concat_collate([d[i] for d in batch]) for i in range(len(batch[0]))] |
|
|
|
|
|
else: |
|
|
return batch |
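
def _demo_recursive_concat_collate():
    # Hypothetical usage sketch: tensors anywhere in a nested dict/list
    # structure are concatenated along dim 0.
    batch = [{'pts': torch.zeros(2, 3)}, {'pts': torch.ones(4, 3)}]
    out = recursive_concat_collate(batch)
    assert out['pts'].shape == (6, 3)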
|
|
|
|
|
def recursive_repeat_interleave_collate(data, dim = 0, rp = 1): |
|
|
|
|
|
if torch.is_tensor(data): |
|
|
return data.repeat_interleave(rp, dim=dim) |
|
|
elif isinstance(data, dict): |
|
|
return {key: recursive_repeat_interleave_collate(value, dim, rp) for key, value in data.items()} |
|
|
elif isinstance(data, list): |
|
|
return [recursive_repeat_interleave_collate(element, dim, rp) for element in data] |
|
|
elif isinstance(data, tuple): |
|
|
return tuple(recursive_repeat_interleave_collate(element, dim, rp) for element in data) |
|
|
else: |
|
|
return data |
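
def _demo_recursive_repeat_interleave_collate():
    # Hypothetical usage sketch: every tensor in the structure is repeated
    # rp times along dim, e.g. [1, 2] -> [1, 1, 2, 2].
    out = recursive_repeat_interleave_collate({'x': torch.tensor([1, 2])}, dim=0, rp=2)
    assert out['x'].tolist() == [1, 1, 2, 2]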
|
|
|
|
|
def combine_dict(dicts, make_list = False): |
|
|
if make_list: |
|
|
dict_all = {k:[] for k in dicts[0].keys()} |
|
|
for dict_i in dicts: |
|
|
for k in dict_i.keys(): |
|
|
dict_all[k].append(dict_i[k]) |
|
|
return dict_all |
|
|
else: |
|
|
dict_all = deepcopy(dicts[0]) |
|
|
for dict_i in dicts[1:]: |
|
|
for k in dict_i.keys(): |
|
|
dict_all[k] = dict_all[k] + dict_i[k] |
|
|
for k in dict_all.keys(): |
|
|
dict_all[k] = dict_all[k] / len(dicts) |
|
|
return dict_all |
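
def _demo_combine_dict():
    # Hypothetical usage sketch: by default values are averaged across dicts;
    # with make_list=True they are gathered into per-key lists instead.
    assert combine_dict([{'psnr': 20.0}, {'psnr': 30.0}])['psnr'] == 25.0
    assert combine_dict([{'psnr': 20.0}, {'psnr': 30.0}], make_list=True)['psnr'] == [20.0, 30.0]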
|
|
|
|
|
def Sum(*losses_and_masks):
    """Sum the given scalar losses; if the first loss is per-pixel (ndim > 0),
    return the raw (loss, mask) pairs instead so the caller can reduce them."""
|
|
loss, mask = losses_and_masks[0] |
|
|
if loss.ndim > 0: |
|
|
|
|
|
return losses_and_masks |
|
|
else: |
|
|
|
|
|
for loss2, mask2 in losses_and_masks[1:]: |
|
|
|
|
|
if isinstance(loss2, list): |
|
|
for loss2_i in loss2: |
|
|
loss = loss + loss2_i |
|
|
else: |
|
|
loss = loss + loss2 |
|
|
return loss |
|
|
|
|
|
def extend_gts(gts, n_ref, bs): |
|
|
gts = recursive_repeat_interleave_collate(gts, 0, n_ref) |
|
|
for data_id in range(bs): |
|
|
for ref_id in range(1, n_ref): |
|
|
for k in gts[0].keys(): |
|
|
recursive_swap(gts[0][k], gts[ref_id][k], data_id * n_ref + ref_id) |
|
|
return gts |
|
|
|
|
|
def swap(a, b): |
|
|
    if isinstance(a, torch.Tensor):
|
|
return b.clone(), a.clone() |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
def swap_ref(a, b): |
|
|
    if isinstance(a, torch.Tensor):
|
|
temp = a.clone() |
|
|
a[:] = b.clone() |
|
|
b[:] = temp |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
def recursive_swap(a, b, pos): |
|
|
|
|
|
if torch.is_tensor(a): |
|
|
a[pos], b[pos] = swap(a[pos], b[pos]) |
|
|
elif isinstance(a, dict): |
|
|
for key in a.keys(): |
|
|
recursive_swap(a[key], b[key], pos) |
|
|
elif isinstance(a, list): |
|
|
for i in range(len(a)): |
|
|
recursive_swap(a[i], b[i], pos) |
|
|
elif isinstance(a, tuple): |
|
|
for i in range(len(a)): |
|
|
recursive_swap(a[i], b[i], pos) |
|
|
else: |
|
|
return |
|
|
|
|
|
def calculate_RRA_RTA(c2w_pred1, c2w_pred2, c2w_gt1, c2w_gt2, eps = 1e-15): |
|
|
""" |
|
|
Return: |
|
|
RRA: [bs,] |
|
|
RTA: [bs,] |
|
|
""" |
|
|
|
|
|
R1 = c2w_pred1[:, :3, :3] |
|
|
R2 = c2w_pred2[:, :3, :3] |
|
|
R1_gt = c2w_gt1[:, :3, :3] |
|
|
R2_gt = c2w_gt2[:, :3, :3] |
|
|
t1 = c2w_pred1[:, :3, 3:] |
|
|
t2 = c2w_pred2[:, :3, 3:] |
|
|
t1_gt = c2w_gt1[:, :3, 3:] |
|
|
t2_gt = c2w_gt2[:, :3, 3:] |
|
|
|
|
|
bs = R1.shape[0] |
|
|
R_pred = R1 @ R2.transpose(-1, -2) |
|
|
R_gt = R1_gt @ R2_gt.transpose(-1, -2) |
|
|
R_diff = R_pred @ R_gt.transpose(-1, -2) |
|
|
|
|
|
P_pred_diff = c2w_pred1 @ torch.linalg.inv(c2w_pred2) |
|
|
P_gt_diff = c2w_gt1 @ torch.linalg.inv(c2w_gt2) |
|
|
|
|
|
trace_R_diff = R_diff[:,0,0] + R_diff[:,1,1] + R_diff[:,2,2] |
|
|
|
|
|
theta = torch.acos((trace_R_diff - 1) / 2) |
|
|
theta[theta > torch.pi] = theta[theta > torch.pi] - 2 * torch.pi |
|
|
theta[theta < -torch.pi] = theta[theta < -torch.pi] + 2 * torch.pi |
|
|
theta = theta * 180 / torch.pi |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # translation lives in the last column for these column-vector c2w matrices
    t_pred = P_pred_diff[:, :3, 3]
    t_gt = P_gt_diff[:, :3, 3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_err = 1e6 |
|
|
t_norm = torch.norm(t_pred, dim=1, keepdim=True) |
|
|
t = t_pred / (t_norm + eps) |
|
|
|
|
|
t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True) |
|
|
t_gt = t_gt / (t_gt_norm + eps) |
|
|
|
|
|
loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps) |
|
|
theta_t = torch.acos(torch.sqrt(1 - loss_t)) |
|
|
|
|
|
theta_t[torch.isnan(theta_t) | torch.isinf(theta_t)] = default_err |
|
|
|
|
|
|
|
|
|
|
|
theta_t[theta_t > torch.pi] = theta_t[theta_t > torch.pi] - 2 * torch.pi |
|
|
theta_t[theta_t < -torch.pi] = theta_t[theta_t < -torch.pi] + 2 * torch.pi |
|
|
theta_t = theta_t * 180 / torch.pi |
|
|
|
|
|
return theta.abs(), theta_t.abs() |
|
|
|
|
|
def calibrate_camera_pnpransac(pointclouds, img_points, masks, intrinsics): |
|
|
""" |
|
|
Input: |
|
|
pointclouds: (bs, N, 3) |
|
|
img_points: (bs, N, 2) |
|
|
Return: |
|
|
rotations: (bs, 3, 3) |
|
|
translations: (bs, 3, 1) |
|
|
c2ws: (bs, 4, 4) |
|
|
""" |
|
|
bs = pointclouds.shape[0] |
|
|
|
|
|
camera_matrix = intrinsics.cpu().numpy() |
|
|
|
|
|
dist_coeffs = np.zeros((5, 1)) |
|
|
|
|
|
rotations = [] |
|
|
translations = [] |
|
|
|
|
|
for i in range(bs): |
|
|
obj_points = pointclouds[i][masks[i]].cpu().numpy() |
|
|
        img_pts = img_points[i][masks[i]].cpu().numpy()
|
|
|
|
|
success, rvec, tvec, inliers = cv2.solvePnPRansac(obj_points, img_pts, camera_matrix[i], dist_coeffs) |
|
|
|
|
|
if success: |
|
|
rotation_matrix, _ = cv2.Rodrigues(rvec) |
|
|
rotations.append(torch.tensor(rotation_matrix, dtype=torch.float32)) |
|
|
translations.append(torch.tensor(tvec, dtype=torch.float32)) |
|
|
else: |
|
|
rotations.append(torch.eye(3)) |
|
|
translations.append(torch.ones(3, 1)) |
|
|
|
|
|
rotations = torch.stack(rotations).to(pointclouds.device) |
|
|
translations = torch.stack(translations).to(pointclouds.device) |
|
|
w2cs = torch.eye(4).repeat(bs, 1, 1).to(pointclouds.device) |
|
|
w2cs[:, :3, :3] = rotations |
|
|
w2cs[:, :3, 3:] = translations |
|
|
return torch.linalg.inv(w2cs) |
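
# Hypothetical usage sketch mirroring the call site further below: recover
# per-view c2w poses from predicted point maps and a pixel grid.
#   c2ws = calibrate_camera_pnpransac(pr_pts.flatten(1, 2),       # (bs, H*W, 3)
#                                     pixel_coords.flatten(1, 2), # (bs, H*W, 2)
#                                     valid.flatten(1, 2),        # (bs, H*W) bool
#                                     intrinsics)                 # (bs, 3, 3)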
|
|
|
|
|
def umeyama_alignment(P1, P2, mask_): |
|
|
|
|
|
""" |
|
|
Return: |
|
|
R: (bs, 3, 3) |
|
|
sigma: (bs, ) |
|
|
t: (bs, 3) |
|
|
""" |
|
|
from pytorch3d import ops |
|
|
ya = ops.corresponding_points_alignment |
|
|
R, T, s = ya(P1, P2, weights = mask_.float(), estimate_scale = True) |
|
|
return R, s, T |
|
|
|
|
|
bs, _ = P1.shape[0:2] |
|
|
ns = mask_.sum(1) |
|
|
mask = mask_[:,:,None] |
|
|
|
|
|
mu1 = (P1 * mask).sum(1) / ns[:, None] |
|
|
mu2 = (P2 * mask).sum(1) / ns[:, None] |
|
|
|
|
|
X1 = P1 - mu1[:, None] |
|
|
X2 = P2 - mu2[:, None] |
|
|
|
|
|
X1_zero = X1 * mask |
|
|
X2_zero = X2 * mask |
|
|
|
|
|
S = (X2_zero.transpose(-1, -2) @ X1_zero) / (ns[:, None, None] + 1e-8) |
|
|
|
|
|
|
|
|
U, D, Vt = torch.linalg.svd(S) |
|
|
|
|
|
|
|
|
d = torch.ones((bs, 3)).to(P1.device) |
|
|
det = torch.linalg.det(U @ Vt) |
|
|
d[:, -1][det < 0] = -1 |
|
|
|
|
|
D_mat = torch.eye(3).to(P1.device).repeat(bs, 1, 1) |
|
|
D_mat[:, -1, -1] = d[:, -1] |
|
|
|
|
|
|
|
|
R = U @ D_mat @ Vt |
|
|
|
|
|
D_diag = torch.diag_embed(D) |
|
|
|
|
|
|
|
|
var1 = torch.square(X1_zero).sum(dim = (1, 2)) / (ns + 1e-8) |
|
|
sigma = (D_mat @ D_diag).diagonal(dim1 = -2, dim2 = -1).sum(-1) / (var1 + 1e-8) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
t = mu2 - sigma[:, None] * (R @ mu1[:, :, None])[:, :, 0] |
|
|
|
|
|
return R, sigma, t |
|
|
|
|
|
def chamfer_distance(pts1, pts2, mask): |
|
|
bs = pts1.shape[0] |
|
|
cd = [] |
|
|
for i in range(bs): |
|
|
disAB = knn_points(pts1[i:i+1][mask[i:i+1]][None], pts2[i:i+1][mask[i:i+1]][None])[0].mean() |
|
|
disBA = knn_points(pts2[i:i+1][mask[i:i+1]][None], pts1[i:i+1][mask[i:i+1]][None])[0].mean() |
|
|
cd.append(disAB + disBA) |
|
|
cd = torch.stack(cd, 0) |
|
|
return cd |
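
def _demo_chamfer_distance():
    # Hypothetical sanity check: identical masked point clouds have zero
    # Chamfer distance (knn_points returns squared nearest-neighbor distances).
    pts = torch.rand(1, 128, 3)
    mask = torch.ones(1, 128, dtype=torch.bool)
    cd = chamfer_distance(pts, pts.clone(), mask)
    assert torch.allclose(cd, torch.zeros(1))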
|
|
|
|
|
def rotationInvMSE(pts3d_normalized, gts3d_normalized, mask_all): |
|
|
|
|
|
R, sigma, t = umeyama_alignment(pts3d_normalized, gts3d_normalized, mask_all) |
|
|
pts3d_normalized_rot = (sigma[:,None,None] * (R @ pts3d_normalized.transpose(-1, -2)).transpose(-1, -2)) + t[:, None] |
|
|
    local_loss = (pts3d_normalized_rot - gts3d_normalized).norm(dim=-1)[mask_all].mean()
    return local_loss
|
|
|
|
|
class LLoss (nn.Module): |
|
|
""" L-norm loss |
|
|
""" |
|
|
|
|
|
def __init__(self, reduction='mean'): |
|
|
super().__init__() |
|
|
self.reduction = reduction |
|
|
|
|
|
def forward(self, a, b, mask = None, reduction=None): |
|
|
if mask is not None: |
|
|
dist = self.distance(a, b) |
|
|
assert reduction == "mean_bs" |
|
|
bs = dist.shape[0] |
|
|
dist_mean = [] |
|
|
for i in range(bs): |
|
|
mask_dist_i = dist[i][mask[i]] |
|
|
if mask_dist_i.numel() > 0: |
|
|
dist_mean.append(mask_dist_i.mean()) |
|
|
else: |
|
|
dist_mean.append(dist.new_zeros(())) |
|
|
dist_mean = torch.stack(dist_mean, 0) |
|
|
return dist_mean |
|
|
|
|
|
assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}' |
|
|
dist = self.distance(a, b) |
|
|
assert dist.ndim == a.ndim-1 |
|
|
reduction_effective = self.reduction |
|
|
if reduction is not None: |
|
|
reduction_effective = reduction |
|
|
if reduction_effective == 'none': |
|
|
return dist |
|
|
if reduction_effective == 'sum': |
|
|
return dist.sum() |
|
|
if reduction_effective == 'mean': |
|
|
return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) |
|
|
        if reduction_effective == 'mean_bs':
            bs = dist.shape[0]
            if dist.numel() == 0:
                return dist.new_zeros((bs,))
            return dist.mean([i for i in range(1, dist.ndim)]) if dist.ndim >= 2 else dist
|
|
        raise ValueError(f'bad {reduction_effective=} mode')
|
|
|
|
|
def distance(self, a, b): |
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
class L21Loss (LLoss): |
|
|
""" Euclidean distance between 3d points """ |
|
|
|
|
|
def distance(self, a, b): |
|
|
return torch.norm(a - b, dim=-1) |
|
|
|
|
|
|
|
|
L21 = L21Loss() |
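
# Hypothetical usage sketch: L21 is the mean Euclidean distance between
# corresponding points, e.g. L21(torch.zeros(2, 3), torch.ones(2, 3)) == sqrt(3).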
|
|
|
|
|
|
|
|
class Criterion (nn.Module): |
|
|
def __init__(self, criterion=None): |
|
|
super().__init__() |
|
|
        assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
|
|
self.criterion = copy(criterion) |
|
|
|
|
|
def get_name(self): |
|
|
return f'{type(self).__name__}({self.criterion})' |
|
|
|
|
|
def with_reduction(self, mode): |
|
|
res = loss = deepcopy(self) |
|
|
while loss is not None: |
|
|
assert isinstance(loss, Criterion) |
|
|
loss.criterion.reduction = 'none' |
|
|
loss = loss._loss2 |
|
|
return res |
|
|
|
|
|
def rearrange_for_mref(self, gt1, gt2s, pred1, pred2s): |
|
|
|
|
|
if gt1['img'].shape[0] == pred1['pts3d'].shape[0]: |
|
|
return gt1, gt2s, pred1, pred2s, 1 |
|
|
bs = gt1['img'].shape[0] |
|
|
bs_pred = pred1['pts3d'].shape[0] |
|
|
n_ref = bs_pred // bs |
|
|
gts = [gt1] + gt2s |
|
|
preds = [pred1] + pred2s |
|
|
|
|
|
gts = extend_gts(gts, n_ref, bs) |
|
|
|
|
|
return gts[0], gts[1:], preds[0], preds[1:], n_ref |
|
|
|
|
|
|
|
|
class MultiLoss (nn.Module): |
|
|
""" Easily combinable losses (also keep track of individual loss values): |
|
|
loss = MyLoss1() + 0.1*MyLoss2() |
|
|
Usage: |
|
|
Inherit from this class and override get_name() and compute_loss() |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self._alpha = 1 |
|
|
self._loss2 = None |
|
|
|
|
|
def compute_loss(self, *args, **kwargs): |
|
|
raise NotImplementedError() |
|
|
|
|
|
def get_name(self): |
|
|
raise NotImplementedError() |
|
|
|
|
|
def __mul__(self, alpha): |
|
|
assert isinstance(alpha, (int, float)) |
|
|
res = copy(self) |
|
|
res._alpha = alpha |
|
|
return res |
|
|
__rmul__ = __mul__ |
|
|
|
|
|
def __add__(self, loss2): |
|
|
assert isinstance(loss2, MultiLoss) |
|
|
res = cur = copy(self) |
|
|
|
|
|
while cur._loss2 is not None: |
|
|
cur = cur._loss2 |
|
|
cur._loss2 = loss2 |
|
|
return res |
|
|
|
|
|
def __repr__(self): |
|
|
name = self.get_name() |
|
|
if self._alpha != 1: |
|
|
name = f'{self._alpha:g}*{name}' |
|
|
if self._loss2: |
|
|
name = f'{name} + {self._loss2}' |
|
|
return name |
|
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
|
|
|
loss = self.compute_loss(*args, **kwargs) |
|
|
if isinstance(loss, tuple): |
|
|
loss, details = loss |
|
|
elif loss.ndim == 0: |
|
|
details = {self.get_name(): float(loss)} |
|
|
else: |
|
|
details = {} |
|
|
loss = loss * self._alpha |
|
|
|
|
|
if self._loss2: |
|
|
loss2, details2 = self._loss2(*args, **kwargs) |
|
|
loss = loss + loss2 |
|
|
details |= details2 |
|
|
|
|
|
return loss, details |
|
|
|
|
|
|
|
|
class Regr3D (Criterion, MultiLoss): |
|
|
""" Ensure that all 3D points are correct. |
|
|
Asymmetric loss: view1 is supposed to be the anchor. |
|
|
|
|
|
P1 = RT1 @ D1 |
|
|
P2 = RT2 @ D2 |
|
|
loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1) |
|
|
loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2) |
|
|
= (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2) |
|
|
""" |
|
|
|
|
|
def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, mv = False, rot_invariant = False, dummy = False): |
|
|
super().__init__(criterion) |
|
|
self.norm_mode = norm_mode |
|
|
self.gt_scale = gt_scale |
|
|
self.mv = mv |
|
|
self.rot_invariant = rot_invariant |
|
|
self.dummy = dummy |
|
|
if mv: |
|
|
self.compute_loss = self.compute_loss_mv |
|
|
|
|
|
def get_all_pts3ds(self, gt1, gt2s, pred1, pred2s, dist_clip=None, **kw): |
|
|
|
|
|
in_camera1 = inv(gt1['camera_pose']) |
|
|
gt_pts1 = geotrf(in_camera1, gt1['pts3d']) |
|
|
gt_pts2s = [geotrf(in_camera1, gt2['pts3d']) for gt2 in gt2s] |
|
|
|
|
|
valid1 = gt1['valid_mask'].clone() |
|
|
valid2s = [gt2['valid_mask'].clone() for gt2 in gt2s] |
|
|
|
|
|
if dist_clip is not None: |
|
|
|
|
|
dis1 = gt_pts1.norm(dim=-1) |
|
|
dis2s = [gt_pts2.norm(dim=-1) for gt_pts2 in gt_pts2s] |
|
|
valid1 = valid1 & (dis1 <= dist_clip) |
|
|
valid2s = [valid2 & (dis2 <= dist_clip) for (valid2, dis2) in zip(valid2s, dis2s)] |
|
|
|
|
|
pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False) |
|
|
pr_pts2s = [get_pred_pts3d(gt2, pred2, use_pose=True) for (gt2, pred2) in zip(gt2s, pred2s)] |
|
|
|
|
|
|
|
|
if self.norm_mode: |
|
|
pr_pts1, pr_pts2s = normalize_pointclouds(pr_pts1, pr_pts2s, self.norm_mode, valid1, valid2s) |
|
|
if self.norm_mode and not self.gt_scale: |
|
|
gt_pts1, gt_pts2s = normalize_pointclouds(gt_pts1, gt_pts2s, self.norm_mode, valid1, valid2s) |
|
|
|
|
|
return gt_pts1, gt_pts2s, pr_pts1, pr_pts2s, valid1, valid2s, {} |
|
|
|
|
|
def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None): |
|
|
|
|
|
in_camera1 = inv(gt1['camera_pose']) |
|
|
gt_pts1 = geotrf(in_camera1, gt1['pts3d']) |
|
|
gt_pts2 = geotrf(in_camera1, gt2['pts3d']) |
|
|
|
|
|
valid1 = gt1['valid_mask'].clone() |
|
|
valid2 = gt2['valid_mask'].clone() |
|
|
|
|
|
if dist_clip is not None: |
|
|
|
|
|
dis1 = gt_pts1.norm(dim=-1) |
|
|
dis2 = gt_pts2.norm(dim=-1) |
|
|
valid1 = valid1 & (dis1 <= dist_clip) |
|
|
valid2 = valid2 & (dis2 <= dist_clip) |
|
|
|
|
|
pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False) |
|
|
pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True) |
|
|
|
|
|
|
|
|
if self.norm_mode: |
|
|
pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2) |
|
|
if self.norm_mode and not self.gt_scale: |
|
|
gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2) |
|
|
|
|
|
return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {} |
|
|
|
|
|
def compute_loss(self, gt1, gt2, pred1, pred2, **kw): |
|
|
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ |
|
|
self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw) |
|
|
|
|
|
l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1]) |
|
|
|
|
|
l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2]) |
|
|
self_name = type(self).__name__ |
|
|
details = {self_name+'_pts3d_1': float(l1.mean()), self_name+'_pts3d_2': float(l2.mean())} |
|
|
return Sum((l1, mask1), (l2, mask2)), (details | monitoring) |
|
|
|
|
|
def compute_loss_mv(self, gt1, gt2s_all, pred1, pred2s, log, **kw): |
|
|
C_avg = gt1['C_avg'].mean() |
|
|
gt1, gt2s_all, pred1, pred2s, n_ref = self.rearrange_for_mref(gt1, gt2s_all, pred1, pred2s) |
|
|
        num_render_views = gt2s_all[0].get("num_render_views", torch.zeros([1]).long())[0].item()  # non-empty default so [0] is valid
|
|
gt2s = gt2s_all[:-num_render_views] if num_render_views else gt2s_all |
|
|
|
|
|
gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, monitoring = \ |
|
|
self.get_all_pts3ds(gt1, gt2s, pred1, pred2s, **kw) |
|
|
|
|
|
nv = len(gt_pts2s) + 1 |
|
|
bs = gt_pts1.shape[0] |
|
|
h = gt_pts1.shape[1] |
|
|
w = gt_pts1.shape[2] |
|
|
|
|
|
|
|
|
if log: |
|
|
pts_thres = 10. |
|
|
gt_pcds_original = [gt1['pts3d']] + [gt2['pts3d'] for gt2 in gt2s] |
|
|
gt_pcds_original = torch.stack(gt_pcds_original, 1) |
|
|
gt_pcds_original = gt_pcds_original.flatten(1, 3) |
|
|
gt_pcds_original_0_1 = torch.quantile(gt_pcds_original, 0.1, dim = 1) |
|
|
gt_pcds_original_0_9 = torch.quantile(gt_pcds_original, 0.9, dim = 1) |
|
|
gt_pcds_original_diff = gt_pcds_original_0_9 - gt_pcds_original_0_1 |
|
|
l1 = self.criterion(pred_pts1, gt_pts1, mask1, reduction='mean_bs') |
|
|
l2s = [self.criterion(pred_pts2, gt_pts2, mask2, reduction='mean_bs') for (gt_pts2, pred_pts2, mask2) in zip(gt_pts2s, pred_pts2s, mask2s)] |
|
|
|
|
|
self_name = type(self).__name__ |
|
|
details = {self_name+'_pts3d_1': float(l1.mean()), self_name+'_pts3d_2': np.mean([float(l2.mean()) for l2 in l2s]).item()} |
|
|
ls = torch.stack([l1.reshape(-1)] + [l2.reshape(-1) for l2 in l2s], -1) |
|
|
ls[ls > pts_thres] = pts_thres |
|
|
ls_mref = ls.reshape(bs // n_ref, n_ref, nv) |
|
|
ls_only_first = ls_mref[:, 0] |
|
|
ls_best = torch.min(ls_mref, dim = 1)[0] |
|
|
|
|
|
details[self_name+'_area0_list'] = (gt_pcds_original_diff[:,1] * gt_pcds_original_diff[:,2]).detach().cpu().tolist() |
|
|
details[self_name+'_area1_list'] = (gt_pcds_original_diff[:,0] * gt_pcds_original_diff[:,2]).detach().cpu().tolist() |
|
|
details[self_name+'_area2_list'] = (gt_pcds_original_diff[:,0] * gt_pcds_original_diff[:,1]).detach().cpu().tolist() |
|
|
details[self_name+'_volume_list'] = (gt_pcds_original_diff[:,2] * gt_pcds_original_diff[:,0] * gt_pcds_original_diff[:,1]).detach().cpu().tolist() |
|
|
details[self_name+'_pts3d_list'] = ls.mean(-1).detach().cpu().tolist() |
|
|
details[self_name+'_pts3d_1_first'] = ls_only_first[:, 0].mean().item() |
|
|
details[self_name+'_pts3d_2_first'] = ls_only_first[:, 1:].mean().item() |
|
|
details[self_name+'_pts3d_first'] = ls_only_first.mean().item() |
|
|
details[self_name+'_pts3d_best'] = ls_best.mean().item() |
|
|
details[self_name+'_pts3d_0.5_accu_list'] = (ls.mean(-1) < 0.5).float().detach().cpu().tolist() |
|
|
details[self_name+'_pts3d_0.3_accu_list'] = (ls.mean(-1) < 0.3).float().detach().cpu().tolist() |
|
|
details[self_name+'_pts3d_0.2_accu_list'] = (ls.mean(-1) < 0.2).float().detach().cpu().tolist() |
|
|
details[self_name+'_pts3d_0.1_accu_list'] = (ls.mean(-1) < 0.1).float().detach().cpu().tolist() |
|
|
details['n_ref'] = n_ref |
|
|
details[self_name+"_C_avg"] = C_avg.item() |
|
|
|
|
|
pred_pts = torch.stack([pred_pts1, *pred_pts2s], 1).flatten(1, 3) |
|
|
gt_pts = torch.stack([gt_pts1, *gt_pts2s], 1).flatten(1, 3) |
|
|
masks = torch.stack([mask1, *mask2s], 1).flatten(1, 3) |
|
|
|
|
|
R, sigma, t = umeyama_alignment(pred_pts, gt_pts, masks) |
|
|
pts3d_normalized_rot = (sigma[:,None,None] * (R @ pred_pts.transpose(-1, -2)).transpose(-1, -2)) + t[:, None] |
|
|
ls_rotInv = (pts3d_normalized_rot - gt_pts).norm(dim = -1).reshape(bs, -1) |
|
|
            ls_rotInv = ls_rotInv * masks  # zero out invalid points (boolean indexing would only write to a copy)
|
|
ls_rotInv = ls_rotInv.sum(-1) / (masks.sum(-1) + 1e-8) |
|
|
details[self_name+'_pts3d_rotInv_list'] = ls_rotInv.detach().cpu().tolist() |
|
|
|
|
|
if self.dummy: |
|
|
details[self_name+'_cd_list'] = [0. for i in range(bs)] |
|
|
details[self_name+'_cd'] = 0. |
|
|
details[self_name+'_cd_first'] = 0. |
|
|
details[self_name+'_cd_best'] = 0. |
|
|
else: |
|
|
cd = chamfer_distance(pred_pts, gt_pts, masks) |
|
|
cd[cd > pts_thres] = pts_thres |
|
|
cd_mref = cd.reshape(bs // n_ref, n_ref) |
|
|
details[self_name+'_cd_list'] = cd.detach().cpu().tolist() |
|
|
details[self_name+'_cd'] = cd.mean().item() |
|
|
details[self_name+'_cd_first'] = cd_mref[:, 0].mean().item() |
|
|
details[self_name+'_cd_best'] = torch.min(cd_mref, dim = 1)[0].mean().item() |
|
|
|
|
|
l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1]) |
|
|
l2s = [self.criterion(pred_pts2[mask2], gt_pts2[mask2]) for (gt_pts2, pred_pts2, mask2) in zip(gt_pts2s, pred_pts2s, mask2s)] |
|
|
|
|
|
else: |
|
|
if self.rot_invariant: |
|
|
pred_pts = torch.stack([pred_pts1, *pred_pts2s], 1).flatten(1,3) |
|
|
gt_pts = torch.stack([gt_pts1, *gt_pts2s], 1).flatten(1,3) |
|
|
mask = torch.stack([mask1, *mask2s], 1).flatten(1,3) |
|
|
|
|
|
R, sigma, t = umeyama_alignment(pred_pts, gt_pts, mask) |
|
|
pts3d_normalized_rot = (sigma[:,None,None] * (R @ pred_pts.transpose(-1, -2)).transpose(-1, -2)) + t[:, None] |
|
|
ls = (pts3d_normalized_rot - gt_pts).norm(dim = -1).reshape(bs, nv, -1) |
|
|
mask = mask.reshape(bs, nv, -1) |
|
|
|
|
|
mask1 = mask[:, 0].reshape(bs, h, w) |
|
|
l1 = ls[:, 0][mask[:, 0]] |
|
|
|
|
|
mask2s = [mask[:, i].reshape(bs, h, w) for i in range(1, nv)] |
|
|
l2s = [ls[:, i][mask[:, i]] for i in range(1, nv)] |
|
|
|
|
|
else: |
|
|
l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1]) |
|
|
l2s = [self.criterion(pred_pts2[mask2], gt_pts2[mask2]) for (gt_pts2, pred_pts2, mask2) in zip(gt_pts2s, pred_pts2s, mask2s)] |
|
|
|
|
|
details = {} |
|
|
|
|
|
return Sum((l1, mask1), (l2s, mask2s)), (details | monitoring) |
|
|
|
|
|
|
|
|
class ConfLoss (MultiLoss): |
|
|
""" Weighted regression by learned confidence. |
|
|
Assuming the input pixel_loss is a pixel-level regression loss. |
|
|
|
|
|
Principle: |
|
|
        high confidence means high conf = 10  ==> conf_loss = x * 10 - alpha*log(10)
        low  confidence means low  conf = 0.1 ==> conf_loss = x / 10  + alpha*log(10)
|
|
|
|
|
alpha: hyperparameter |
|
|
""" |
|
|
|
|
|
def __init__(self, pixel_loss, alpha=1): |
|
|
super().__init__() |
|
|
assert alpha >= 0 |
|
|
self.alpha = alpha |
|
|
self.pixel_loss = pixel_loss.with_reduction('none') |
|
|
self.mv = self.pixel_loss.mv |
|
|
if self.mv: |
|
|
self.compute_loss = self.compute_loss_mv |
|
|
|
|
|
def get_name(self): |
|
|
return f'ConfLoss({self.pixel_loss})' |
|
|
|
|
|
def get_conf_log(self, x): |
|
|
return x, torch.log(x) |
|
|
|
|
|
def compute_loss_mv(self, gt1, gt2s, pred1, pred2s, **kw): |
|
|
((loss1, msk1), (loss2s, msk2s)), details = self.pixel_loss(gt1, gt2s, pred1, pred2s, **kw) |
|
|
|
|
|
conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1]) |
|
|
conf_loss1 = loss1 * conf1 - self.alpha * log_conf1 |
|
|
conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0 |
|
|
if loss1.numel() == 0: |
|
|
print('NO VALID POINTS in img1', force=True) |
|
|
|
|
|
conf_loss2s = [] |
|
|
for (loss2, msk2, pred2) in zip(loss2s, msk2s, pred2s): |
|
|
if loss2.numel() == 0: |
|
|
print('NO VALID POINTS in img2', force=True) |
|
|
|
|
|
|
|
|
conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2]) |
|
|
|
|
|
conf_loss2 = loss2 * conf2 - self.alpha * log_conf2 |
|
|
|
|
|
|
|
|
conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0 |
|
|
conf_loss2s.append(conf_loss2) |
|
|
conf_loss2 = sum(conf_loss2s) |
|
|
|
|
|
return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details) |
|
|
|
|
|
|
|
|
def compute_loss(self, gt1, gt2, pred1, pred2, **kw): |
|
|
|
|
|
((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw) |
|
|
if loss1.numel() == 0: |
|
|
print('NO VALID POINTS in img1', force=True) |
|
|
if loss2.numel() == 0: |
|
|
print('NO VALID POINTS in img2', force=True) |
|
|
|
|
|
|
|
|
conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1]) |
|
|
conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2]) |
|
|
conf_loss1 = loss1 * conf1 - self.alpha * log_conf1 |
|
|
conf_loss2 = loss2 * conf2 - self.alpha * log_conf2 |
|
|
|
|
|
|
|
|
conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0 |
|
|
conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0 |
|
|
|
|
|
return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details) |
|
|
|
|
|
|
|
|
class Regr3D_ShiftInv (Regr3D): |
|
|
""" Same than Regr3D but invariant to depth shift. |
|
|
""" |
|
|
|
|
|
def get_all_pts3d(self, gt1, gt2, pred1, pred2): |
|
|
|
|
|
if self.mv: |
|
|
return self.get_all_pts3ds(gt1, gt2, pred1, pred2) |
|
|
|
|
|
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ |
|
|
super().get_all_pts3d(gt1, gt2, pred1, pred2) |
|
|
|
|
|
|
|
|
gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2] |
|
|
pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2] |
|
|
gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None] |
|
|
pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None] |
|
|
|
|
|
|
|
|
gt_z1 -= gt_shift_z |
|
|
gt_z2 -= gt_shift_z |
|
|
pred_z1 -= pred_shift_z |
|
|
pred_z2 -= pred_shift_z |
|
|
|
|
|
|
|
|
return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring |
|
|
|
|
|
def get_all_pts3ds(self, gt1, gt2s, pred1, pred2s, **kw): |
|
|
|
|
|
gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, monitoring = \ |
|
|
super().get_all_pts3ds(gt1, gt2s, pred1, pred2s) |
|
|
|
|
|
|
|
|
gt_z1, gt_z2s = gt_pts1[..., 2], [gt_pts2[..., 2] for gt_pts2 in gt_pts2s] |
|
|
pred_z1, pred_z2s = pred_pts1[..., 2], [pred_pts2[..., 2] for pred_pts2 in pred_pts2s] |
|
|
gt_shift_z = get_joint_pointcloud_depths(gt_z1, gt_z2s, mask1, mask2s)[:, None, None] |
|
|
pred_shift_z = get_joint_pointcloud_depths(pred_z1, pred_z2s, mask1, mask2s)[:, None, None] |
|
|
|
|
|
|
|
|
gt_z1 -= gt_shift_z |
|
|
for gt_z2 in gt_z2s: |
|
|
gt_z2 -= gt_shift_z |
|
|
pred_z1 -= pred_shift_z |
|
|
for pred_z2 in pred_z2s: |
|
|
pred_z2 -= pred_shift_z |
|
|
|
|
|
|
|
|
return gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, monitoring |
|
|
|
|
|
class Regr3D_ShiftAllInv (Regr3D): |
|
|
""" Same than Regr3D but invariant to shift of xyz (center to original) |
|
|
""" |
|
|
|
|
|
def get_all_pts3ds(self, gt1, gt2s, pred1, pred2s, **kw): |
|
|
|
|
|
gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, monitoring = \ |
|
|
super().get_all_pts3ds(gt1, gt2s, pred1, pred2s) |
|
|
|
|
|
for coor_id in range(3): |
|
|
|
|
|
gt_z1, gt_z2s = gt_pts1[..., coor_id], [gt_pts2[..., coor_id] for gt_pts2 in gt_pts2s] |
|
|
pred_z1, pred_z2s = pred_pts1[..., coor_id], [pred_pts2[..., coor_id] for pred_pts2 in pred_pts2s] |
|
|
gt_shift_z = get_joint_pointcloud_depths(gt_z1, gt_z2s, mask1, mask2s)[:, None, None] |
|
|
pred_shift_z = get_joint_pointcloud_depths(pred_z1, pred_z2s, mask1, mask2s)[:, None, None] |
|
|
|
|
|
|
|
|
gt_z1 -= gt_shift_z |
|
|
for gt_z2 in gt_z2s: |
|
|
gt_z2 -= gt_shift_z |
|
|
pred_z1 -= pred_shift_z |
|
|
for pred_z2 in pred_z2s: |
|
|
pred_z2 -= pred_shift_z |
|
|
|
|
|
|
|
|
return gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, monitoring |
|
|
|
|
|
|
|
|
class Regr3D_ScaleInv (Regr3D): |
|
|
""" Same than Regr3D but invariant to depth shift. |
|
|
if gt_scale == True: enforce the prediction to take the same scale than GT |
|
|
""" |
|
|
|
|
|
def get_all_pts3d(self, gt1, gt2, pred1, pred2): |
|
|
|
|
|
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2) |
|
|
|
|
|
|
|
|
_, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2) |
|
|
_, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2) |
|
|
|
|
|
|
|
|
pred_scale = pred_scale.clip(min=1e-3, max=1e3) |
|
|
|
|
|
|
|
|
if self.gt_scale: |
|
|
pred_pts1 *= gt_scale / pred_scale |
|
|
pred_pts2 *= gt_scale / pred_scale |
|
|
|
|
|
else: |
|
|
gt_pts1 /= gt_scale |
|
|
gt_pts2 /= gt_scale |
|
|
pred_pts1 /= pred_scale |
|
|
pred_pts2 /= pred_scale |
|
|
|
|
|
|
|
|
return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring |
|
|
|
|
|
def get_all_pts3ds(self, gt1, gt2s, pred1, pred2s, **kw): |
|
|
|
|
|
gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, monitoring = super().get_all_pts3ds(gt1, gt2s, pred1, pred2s) |
|
|
|
|
|
|
|
|
_, gt_scale = get_joint_pointcloud_center_scales(gt_pts1, gt_pts2s, mask1, mask2s) |
|
|
_, pred_scale = get_joint_pointcloud_center_scales(pred_pts1, pred_pts2s, mask1, mask2s) |
|
|
|
|
|
|
|
|
pred_scale = pred_scale.clip(min=1e-3, max=1e3) |
|
|
|
|
|
|
|
|
if self.gt_scale: |
|
|
pred_pts1 *= gt_scale / pred_scale |
|
|
for pred_pts2 in pred_pts2s: |
|
|
pred_pts2 *= gt_scale / pred_scale |
|
|
|
|
|
else: |
|
|
gt_pts1 /= gt_scale |
|
|
for gt_pts2 in gt_pts2s: |
|
|
gt_pts2 /= gt_scale |
|
|
pred_pts1 /= pred_scale |
|
|
for pred_pts2 in pred_pts2s: |
|
|
pred_pts2 /= pred_scale |
|
|
|
|
|
|
|
|
return gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, monitoring |
|
|
|
|
|
|
|
|
class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv): |
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
class Regr3D_ScaleShiftAllInv (Regr3D_ScaleInv, Regr3D_ShiftAllInv): |
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
class CalcMetrics(): |
|
|
|
|
|
def __init__(self, random_crop_size = None, resize = None): |
|
|
from torchmetrics.functional import structural_similarity_index_measure |
|
|
from torchmetrics.image import PeakSignalNoiseRatio |
|
|
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity |
|
|
|
|
|
self.psnr = PeakSignalNoiseRatio(data_range=1.0) |
|
|
self.ssim = structural_similarity_index_measure |
|
|
self.lpips = LearnedPerceptualImagePatchSimilarity() |
|
|
|
|
|
from torchvision import transforms |
|
|
self.random_crop = transforms.RandomCrop(random_crop_size) if random_crop_size is not None else None |
|
|
self.resize = transforms.Resize(resize) if resize is not None else None |
|
|
|
|
|
self.laplacian_kernel = torch.tensor([[1, 1, 1], |
|
|
[1, -8, 1], |
|
|
[1, 1, 1]], dtype=torch.float32).view(1, 1, 3, 3) |
|
|
|
|
|
def calc_metrics(self, img_gt, img_, log = True): |
|
|
|
|
|
self.psnr = self.psnr.to(img_.device) |
|
|
self.lpips = self.lpips.to(img_.device) |
|
|
img = torch.clip(img_, -1, 1) |
|
|
|
|
|
results = {} |
|
|
img_01 = (img + 1) / 2 |
|
|
img_gt_01 = (img_gt + 1) / 2 |
|
|
if log: |
|
|
results['psnr'] = float(self.psnr(img_gt_01, img_01).item()) |
|
|
results['ssim'] = float(self.ssim(img_gt_01[None], img_01[None]).item()) |
|
|
results['lpips'] = float(self.lpips(img_gt[None], img[None]).item()) |
|
|
else: |
|
|
results['psnr'] = 0. |
|
|
results['ssim'] = 0. |
|
|
results['lpips'] = 0. |
|
|
return results |
|
|
|
|
|
def calc_lpips(self, img_gt, img): |
|
|
|
|
|
self.lpips = self.lpips.to(img.device) |
|
|
if self.random_crop is not None: |
|
|
all_img = torch.cat([img_gt, img], dim=0) |
|
|
all_img = self.random_crop(all_img) |
|
|
img_gt, img = all_img[:img_gt.shape[0]], all_img[img_gt.shape[0]:] |
|
|
if self.resize is not None: |
|
|
img_gt = self.resize(img_gt) |
|
|
img = self.resize(img) |
|
|
img = torch.clip(img, -1, 1) |
|
|
img_gt = torch.clip(img_gt, -1, 1) |
|
|
return self.lpips(img_gt, img) |
|
|
|
|
|
def laplace(self, img): |
|
|
|
|
|
|
|
|
laplacian_kernel = self.laplacian_kernel.to(img.device) |
|
|
return F.conv2d(img.permute(0,3,1,2), laplacian_kernel, padding=1).permute(0,2,3,1) |
|
|
|
|
|
|
|
|
|
|
|
calc_metrics = CalcMetrics(resize = (224, 224)) |
|
|
|
|
|
class GSRenderLoss (Criterion, MultiLoss): |
|
|
|
|
|
def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, mv = False, render_included = False, scale_scaled = True, use_gt_pcd = False, lpips_coeff = 0., rgb_coeff = 1.0, copy_rgb_coeff = 10.0, use_img_rgb = False, cam_relocation = False, local_loss_coeff = 0., lap_loss_coeff = 0.): |
|
|
super().__init__(criterion) |
|
|
self.norm_mode = norm_mode |
|
|
self.gt_scale = gt_scale |
|
|
self.mv = mv |
|
|
self.compute_loss = None |
|
|
if mv: |
|
|
self.compute_loss = self.compute_loss_mv |
|
|
self.render_included = render_included |
|
|
self.scale_scaled = scale_scaled |
|
|
self.use_gt_pcd = use_gt_pcd |
|
|
self.lpips_coeff = lpips_coeff |
|
|
self.rgb_coeff = rgb_coeff |
|
|
self.copy_rgb_coeff = copy_rgb_coeff |
|
|
self.use_img_rgb = use_img_rgb |
|
|
self.cam_relocation = cam_relocation |
|
|
self.local_loss_coeff = local_loss_coeff |
|
|
self.lap_loss_coeff = lap_loss_coeff |
|
|
|
|
|
from dust3r.gs import GaussianRenderer |
|
|
self.gs_renderer = GaussianRenderer() |
|
|
|
|
|
|
|
|
def get_all_pts3ds_to_canonical_cam(self, gt1, gt2s, pred1, pred2s, log = False, dist_clip=None, **kw): |
|
|
|
|
|
|
|
|
in_camera1 = inv(gt1['camera_pose']) |
|
|
gt_pts1 = geotrf(in_camera1, gt1['pts3d']) |
|
|
gt_pts2s = [geotrf(in_camera1, gt2['pts3d']) for gt2 in gt2s] |
|
|
|
|
|
c2ws = [gt1['camera_pose']] + [gt2['camera_pose'] for gt2 in gt2s] |
|
|
c2ws = [in_camera1 @ c2w for c2w in c2ws] |
|
|
|
|
|
valid1 = gt1['valid_mask'].clone() |
|
|
valid2s = [gt2['valid_mask'].clone() for gt2 in gt2s] |
|
|
|
|
|
if dist_clip is not None: |
|
|
|
|
|
dis1 = gt_pts1.norm(dim=-1) |
|
|
dis2s = [gt_pts2.norm(dim=-1) for gt_pts2 in gt_pts2s] |
|
|
valid1 = valid1 & (dis1 <= dist_clip) |
|
|
valid2s = [valid2 & (dis2 <= dist_clip) for (valid2, dis2) in zip(valid2s, dis2s)] |
|
|
|
|
|
pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False) |
|
|
pr_pts2s = [get_pred_pts3d(gt2, pred2, use_pose=True) for (gt2, pred2) in zip(gt2s, pred2s)] |
|
|
|
|
|
|
|
|
if self.norm_mode: |
|
|
pr_pts1, pr_pts2s, pr_norm_factor = normalize_pointclouds(pr_pts1, pr_pts2s, self.norm_mode, valid1, valid2s, return_norm_factor = True) |
|
|
pr_norm_factor = pr_norm_factor.detach() |
|
|
if self.norm_mode and not self.gt_scale: |
|
|
gt_pts1, gt_pts2s, gt_norm_factor = normalize_pointclouds(gt_pts1, gt_pts2s, self.norm_mode, valid1, valid2s, return_norm_factor = True) |
|
|
while gt_norm_factor.ndim < c2ws[0].ndim: |
|
|
gt_norm_factor.unsqueeze_(-1) |
|
|
pr_norm_factor.unsqueeze_(-1) |
|
|
for c2w in c2ws: |
|
|
|
|
|
c2w[:,:3,3:] = c2w[:,:3,3:] / gt_norm_factor |
|
|
if 'scale' in pred1.keys(): |
|
|
while gt_norm_factor.ndim < pred1['scale'].ndim: |
|
|
gt_norm_factor.unsqueeze_(-1) |
|
|
pr_norm_factor.unsqueeze_(-1) |
|
|
for pred in [pred1] + pred2s: |
|
|
if self.scale_scaled: |
|
|
pred['scale'][:] = pred['scale'][:] / pr_norm_factor |
|
|
        extra_info = {'monitoring': {}}
|
|
if log: |
|
|
with torch.no_grad(): |
|
|
nv = len(pred2s) + 1 |
|
|
bs, h, w = gt_pts1.shape[0:3] |
|
|
|
|
|
gts = [gt1] + gt2s |
|
|
preds = [pred1] + pred2s |
|
|
pr_pts = [pr_pts1] + pr_pts2s |
|
|
valids = [valid1] + valid2s |
|
|
|
|
|
y_coords, x_coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') |
|
|
pixel_coords = torch.stack([x_coords, y_coords], dim=-1) |
|
|
pixel_coords = pixel_coords.to(gt_pts1.device).repeat(gt_pts1.shape[0], 1, 1, 1).float() |
|
|
|
|
|
conf = preds[0]['conf'].reshape(bs, -1) |
|
|
conf_sorted = conf.sort()[0] |
|
|
conf_thres = conf_sorted[:, int(conf.shape[1] * 0.03)] |
|
|
valid1 = (conf >= conf_thres[:, None]) |
|
|
valid1 = valid1.reshape(bs, h, w) |
|
|
intrinsics = [] |
|
|
pts3d = preds[0]['pts3d'] |
|
|
|
|
|
for i in range(bs): |
|
|
focal_i = estimate_focal_knowing_depth(pts3d[i:i+1], valid1[i:i+1]) |
|
|
intrinsics_i = torch.eye(3, device=pts3d.device) |
|
|
intrinsics_i[0, 0] = focal_i |
|
|
intrinsics_i[1, 1] = focal_i |
|
|
intrinsics_i[0, 2] = w / 2 |
|
|
intrinsics_i[1, 2] = h / 2 |
|
|
intrinsics.append(intrinsics_i) |
|
|
intrinsics = torch.stack(intrinsics, dim=0) |
|
|
|
|
|
|
|
|
for (gt, pr_pt, pred, valid) in zip(gts, pr_pts, preds, valids): |
|
|
|
|
|
gt_intrinsics = gt['camera_intrinsics'][:,:3,:3] |
|
|
|
|
|
|
|
|
|
|
|
if 'c2ws_pred' not in pred.keys(): |
|
|
c2ws_pred = calibrate_camera_pnpransac(pr_pt.flatten(1,2), pixel_coords.flatten(1,2), valid.flatten(1,2), intrinsics) |
|
|
pred['c2ws_pred'] = c2ws_pred |
|
|
else: |
|
|
intrinsics = preds[0]['intrinsics_pred'] |
|
|
c2ws_pred = calibrate_camera_pnpransac(pr_pt.flatten(1,2), pixel_coords.flatten(1,2), valid.flatten(1,2), intrinsics) |
|
|
pred['c2ws_pred'] = c2ws_pred |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
c2ws_all = torch.stack(c2ws, dim=1)[:,:nv].flatten(0,1) |
|
|
c2ws_pred = torch.stack([pred['c2ws_pred'] for pred in preds], dim=1)[:,:nv].flatten(0, 1) |
|
|
r, t = camera_to_rel_deg(torch.linalg.inv(c2ws_pred), torch.linalg.inv(c2ws_all), c2ws_all.device, bs) |
|
|
theta_Rs = r.reshape(bs, -1) |
|
|
theta_ts = t.reshape(bs, -1) |
|
|
|
|
|
RRA_thres = 15 |
|
|
RTA_thres = 15 |
|
|
mAA_thres = 30 |
|
|
RRA = (theta_Rs < RRA_thres).float().mean(-1) |
|
|
RTA = (theta_ts < RTA_thres).float().mean(-1) |
|
|
mAA = torch.zeros((bs,)).to(RRA.device) |
|
|
for thres in range(1, mAA_thres + 1): |
|
|
mAA += ((theta_Rs < thres) * (theta_ts < thres)).float().mean(-1) |
|
|
mAA /= mAA_thres |
|
|
extra_info['RRA'] = RRA |
|
|
extra_info['RTA'] = RTA |
|
|
extra_info['mAA'] = mAA |
|
|
|
|
|
return gt_pts1, gt_pts2s, pr_pts1, pr_pts2s, valid1, valid2s, c2ws, pr_norm_factor, extra_info |
|
|
|
|
|
def get_all_pts3ds(self, gt1, gt2s, pred1, pred2s, dist_clip=None, **kw): |
|
|
|
|
|
in_camera1 = inv(gt1['camera_pose']) |
|
|
gt_pts1 = geotrf(in_camera1, gt1['pts3d']) |
|
|
gt_pts2s = [geotrf(in_camera1, gt2['pts3d']) for gt2 in gt2s] |
|
|
|
|
|
valid1 = gt1['valid_mask'].clone() |
|
|
valid2s = [gt2['valid_mask'].clone() for gt2 in gt2s] |
|
|
|
|
|
if dist_clip is not None: |
|
|
|
|
|
dis1 = gt_pts1.norm(dim=-1) |
|
|
dis2s = [gt_pts2.norm(dim=-1) for gt_pts2 in gt_pts2s] |
|
|
valid1 = valid1 & (dis1 <= dist_clip) |
|
|
valid2s = [valid2 & (dis2 <= dist_clip) for (valid2, dis2) in zip(valid2s, dis2s)] |
|
|
|
|
|
pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False) |
|
|
pr_pts2s = [get_pred_pts3d(gt2, pred2, use_pose=True) for (gt2, pred2) in zip(gt2s, pred2s)] |
|
|
|
|
|
|
|
|
if self.norm_mode: |
|
|
pr_pts1, pr_pts2s = normalize_pointclouds(pr_pts1, pr_pts2s, self.norm_mode, valid1, valid2s) |
|
|
if self.norm_mode and not self.gt_scale: |
|
|
gt_pts1, gt_pts2s = normalize_pointclouds(gt_pts1, gt_pts2s, self.norm_mode, valid1, valid2s) |
|
|
|
|
|
return gt_pts1, gt_pts2s, pr_pts1, pr_pts2s, valid1, valid2s, {} |
|
|
|
|
|
|
|
|
|
|
|
def local_lap_loss(self, pts3d, gts3d, c2ws_all, mask_all): |
|
|
|
|
|
cam_centers = c2ws_all[:, :3, 3] |
|
|
pts_dis = (pts3d - cam_centers[:,None,None]).norm(dim = -1) |
|
|
gts_dis = (gts3d - cam_centers[:,None,None]).norm(dim = -1) |
|
|
pts_dis_lap = calc_metrics.laplace(pts_dis.unsqueeze(-1)) |
|
|
gts_dis_lap = calc_metrics.laplace(gts_dis.unsqueeze(-1)) |
|
|
lap_loss = (pts_dis_lap - gts_dis_lap).abs().squeeze(-1) |
|
|
lap_loss[~mask_all] = 0 |
|
|
lap_loss = lap_loss.mean() |
|
|
return lap_loss |
|
|
|
|
|
def local_loss(self, pts3d_, gts3d_, c2ws_all_, mask_all_, conf_all_, real_bs): |
|
|
|
|
|
loss_type = "dis" |
|
|
loss_type = "only_T" |
|
|
|
|
|
|
|
|
bs = pts3d_.shape[0] |
|
|
|
|
|
pts3d = pts3d_.clone() |
|
|
gts3d = gts3d_.clone() |
|
|
c2ws_all = c2ws_all_.clone() |
|
|
mask_all = mask_all_.clone() |
|
|
conf_all = conf_all_.clone() |
|
|
|
|
|
|
|
|
pts3d[~mask_all] = 0. |
|
|
gts3d[~mask_all] = 0. |
|
|
|
|
|
if "dis" in loss_type: |
|
|
mask_all_flatten = mask_all.reshape(real_bs, -1) |
|
|
pts3d = pts3d.reshape(real_bs, -1, 3) |
|
|
gts3d = gts3d.reshape(real_bs, -1, 3) |
|
|
id_list_1 = torch.randint(0, pts3d.shape[1], (real_bs, pts3d.shape[1],), device = pts3d.device) |
|
|
valid_1 = mask_all_flatten.gather(1, id_list_1) |
|
|
id_list_2 = torch.randint(0, pts3d.shape[1], (real_bs, pts3d.shape[1],), device = pts3d.device) |
|
|
valid_2 = mask_all_flatten.gather(1, id_list_2) |
|
|
valid = valid_1.bool() & valid_2.bool() |
|
|
pts3d_1 = torch.gather(pts3d, 1, id_list_1[:, :, None].repeat(1, 1, 3)) |
|
|
pts3d_2 = torch.gather(pts3d, 1, id_list_2[:, :, None].repeat(1, 1, 3)) |
|
|
gts3d_1 = torch.gather(gts3d, 1, id_list_1[:, :, None].repeat(1, 1, 3)) |
|
|
gts3d_2 = torch.gather(gts3d, 1, id_list_2[:, :, None].repeat(1, 1, 3)) |
|
|
pts3d_dis = (pts3d_1 - pts3d_2).norm(dim = -1) |
|
|
gts3d_dis = (gts3d_1 - gts3d_2).norm(dim = -1) |
|
|
pts3d_dis = pts3d_dis * valid |
|
|
gts3d_dis = gts3d_dis * valid |
|
|
pts3d_dis_normalized = pts3d_dis / (pts3d_dis.mean(-1, keepdim = True) + 1e-8) |
|
|
gts3d_dis_normalized = gts3d_dis / (gts3d_dis.mean(-1, keepdim = True) + 1e-8) |
|
|
local_loss = (pts3d_dis_normalized - gts3d_dis_normalized).abs().mean() |
|
|
|
|
|
return local_loss |
|
|
|
|
|
pts3d = pts3d.reshape(bs, -1, 3) |
|
|
gts3d = gts3d.reshape(bs, -1, 3) |
|
|
mask_all = mask_all.reshape(bs, -1) |
|
|
n_p = mask_all.sum(-1) |
|
|
pts3d_mean = pts3d.sum(1) / (n_p[:, None] + 1e-8) |
|
|
gts3d_mean = gts3d.sum(1) / (n_p[:, None] + 1e-8) |
|
|
pts3d_center = pts3d - pts3d_mean[:, None] |
|
|
gts3d_center = gts3d - gts3d_mean[:, None] |
|
|
pts3d_norm = pts3d_center.norm(dim = -1) |
|
|
pts3d_norm_mean = pts3d_norm.sum(-1) / (n_p + 1e-8) |
|
|
gts3d_norm = gts3d_center.norm(dim = -1) |
|
|
gts3d_norm_mean = gts3d_norm.sum(-1) / (n_p + 1e-8) |
|
|
pts3d_normalized = pts3d_center / (pts3d_norm_mean[:, None, None] + 1e-8) |
|
|
gts3d_normalized = gts3d_center / (gts3d_norm_mean[:, None, None] + 1e-8) |
|
|
|
|
|
if "only_T" in loss_type: |
|
|
local_loss = (pts3d_normalized - gts3d_normalized).norm(dim = -1).mean() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Return: |
|
|
R: (bs, 3, 3) |
|
|
sigma: (bs, ) |
|
|
t: (bs, 3) |
|
|
""" |
|
|
if "RT" in loss_type: |
|
|
R, sigma, t = umeyama_alignment(pts3d_normalized, gts3d_normalized, mask_all) |
|
|
pts3d_normalized_rot = (sigma[:,None,None] * (R @ pts3d_normalized.transpose(-1, -2)).transpose(-1, -2)) + t[:, None] |
|
|
local_loss = (pts3d_normalized_rot - gts3d_normalized).norm(dim = -1)[mask_all].mean() |
|
|
|
|
|
return local_loss |
|
|
|
|
|
def compute_loss_mv(self, gt1, gt2s, pred1, pred2s, log, scale_range = [0.0001, 0.004], **kw): |
|
|
gt1, gt2s, pred1, pred2s, n_ref = self.rearrange_for_mref(gt1, gt2s, pred1, pred2s) |
|
|
gt_pts1, gt_pts2s, pred_pts1, pred_pts2s, mask1, mask2s, c2ws, pr_norm_factor, extra_info = \ |
|
|
self.get_all_pts3ds_to_canonical_cam(gt1, gt2s, pred1, pred2s, log, **kw) |
|
|
        monitoring = extra_info['monitoring']
|
|
|
|
|
|
|
|
|
|
|
        num_render_views = gt2s[0].get("num_render_views", torch.zeros([1]).long())[0].item()  # non-empty default so [0] is valid
|
|
gt_pts2s_inference = gt_pts2s[:-num_render_views] if num_render_views else gt_pts2s |
|
|
|
|
|
preds = [pred1] + pred2s |
|
|
gts = [gt1] + gt2s |
|
|
|
|
|
bs, h, w = gt_pts1.shape[:3] |
|
|
nv = len(gt2s) + 1 |
|
|
n_inference = len(pred_pts2s) + 1 |
|
|
mask_all = torch.stack([mask1] + mask2s, 1) |
|
|
conf_all = torch.stack([pred1['conf']] + [pred2['conf'] for pred2 in pred2s], 1) |
|
|
self.gs_renderer.set_view_info(height = h, width = w) |
|
|
|
|
|
loss_rgb = 0. |
|
|
loss_scale = 0. |
|
|
loss_lpips = 0. |
|
|
loss_copy_rgb = 0. |
|
|
loss_local = 0. |
|
|
loss_lap = 0. |
|
|
|
|
|
log_scale = 0. |
|
|
|
|
|
render_all = [] |
|
|
render_relocated_all = [] |
|
|
photometrics_all = [] |
|
|
photometrics_inference_all = [] |
|
|
photometrics_render_all = [] |
|
|
pts3d_all = torch.stack([pred_pts1] + [pred_pts2 for pred_pts2 in pred_pts2s], 1) |
|
|
gts3d_all = torch.stack([gt_pts1] + [gt_pts2 for gt_pts2 in gt_pts2s_inference], 1) |
|
|
c2ws_gs_all = torch.stack([c2w for c2w in c2ws], 1) |
|
|
|
|
|
if self.local_loss_coeff: |
|
|
loss_local = self.local_loss(pts3d_all.flatten(0, 1), gts3d_all.flatten(0, 1), c2ws_gs_all[:,:n_inference].flatten(0, 1), mask_all[:,:n_inference].flatten(0, 1), conf_all.flatten(0, 1), bs) |
|
|
if self.lap_loss_coeff: |
|
|
loss_lap = self.local_lap_loss(pts3d_all.flatten(0, 1), gts3d_all.flatten(0, 1), c2ws_gs_all[:,:n_inference].flatten(0, 1), mask_all[:,:n_inference].flatten(0, 1)) |
|
|
for dp_id in range(bs): |
|
|
|
|
|
|
|
|
|
|
|
c2ws_gs = c2ws_gs_all[dp_id] |
|
|
gt_imgs = torch.stack([gt1['img'][dp_id]] + [gt2['img'][dp_id] for gt2 in gt2s], 0).permute(0,2,3,1) |
|
|
|
|
|
intrinsics = torch.stack([gt1['camera_intrinsics'][dp_id][:3,:3]] + [gt2['camera_intrinsics'][dp_id][:3,:3] for gt2 in gt2s]).cuda() |
|
|
|
|
|
|
|
|
pts3d = pts3d_all[dp_id] |
|
|
gts3d = gts3d_all[dp_id] |
|
|
if self.use_gt_pcd: |
|
|
pts3d = pts3d * 0 + gts3d.detach() |
|
|
|
|
|
pts3d_gs = pts3d.reshape(-1, 3) |
|
|
rgb_gs = torch.stack([pred1['rgb'][dp_id]] + [pred2['rgb'][dp_id] for pred2 in pred2s], 0).flatten(0, -2) |
|
|
rot_gs = torch.stack([pred1['rotation'][dp_id]] + [pred2['rotation'][dp_id] for pred2 in pred2s], 0).reshape(-1, 4) |
|
|
scale_gs = torch.stack([pred1['scale'][dp_id]] + [pred2['scale'][dp_id] for pred2 in pred2s], 0).reshape(-1, 3) |
|
|
|
|
|
scale_clip_above = torch.clip(scale_gs - scale_range[1], min = 0) |
|
|
scale_clip_below = torch.clip(scale_range[0] - scale_gs, min = 0) |
|
|
scale_clip_loss = torch.square(scale_clip_above + scale_clip_below).mean() |
|
|
scale_gs = torch.clip(scale_gs, scale_range[0], scale_range[1]) |
|
|
|
|
|
opacity_gs = torch.stack([pred1['opacity'][dp_id]] + [pred2['opacity'][dp_id] for pred2 in pred2s], 0).reshape(-1) |
|
|
|
|
|
|
|
|
if self.rgb_coeff or dp_id == 0: |
|
|
sh_base = rgb_gs.shape[-1] // 3 |
|
|
                SH = not (self.use_img_rgb and sh_base == 1)
|
|
if self.use_img_rgb: |
|
|
if sh_base == 1: |
|
|
                        # keep predicted rgb in the autograd graph while substituting GT colors;
                        # parentheses ensure the substitution happens whether or not render views exist
                        rgb_gs = (rgb_gs[:] * 0).mean() + (gt_imgs[:-num_render_views] if num_render_views else gt_imgs)
|
|
rgb_gs = rgb_gs.reshape(-1, 3) |
|
|
else: |
|
|
sh_degree = int(np.sqrt(sh_base)) - 1 |
|
|
pts3d_gs_copy = pts3d.reshape(n_inference, -1, 3) |
|
|
rgb_gs_copy = rgb_gs.reshape(n_inference, -1, sh_base, 3) |
|
|
rgb_copy = self.gs_renderer.calc_color_from_sh(pts3d_gs_copy, c2ws_gs[:n_inference], rgb_gs_copy, sh_degree) |
|
|
|
|
|
l_rgb_copy = self.criterion(rgb_copy.reshape(-1, 3), gt_imgs[:n_inference].reshape(-1, 3)) |
|
|
loss_copy_rgb = loss_copy_rgb + l_rgb_copy |
|
|
if self.cam_relocation: |
|
|
valid_mask = mask_all[dp_id][:n_inference].reshape(n_inference, -1) |
|
|
conf = conf_all[dp_id].reshape(n_inference, -1) |
|
|
conf_sorted = conf.sort()[0] |
|
|
conf_thres = conf_sorted[:, int(conf.shape[1] * 0.03)] |
|
|
valid_mask = valid_mask & (conf >= conf_thres[:, None]) |
|
|
R, sigma, t = umeyama_alignment(pts3d.reshape(n_inference, -1, 3), gts3d.reshape(n_inference, -1, 3), valid_mask) |
|
|
Rt = torch.eye(4).to(R.device).repeat(n_inference, 1, 1) |
|
|
Rt[:, :3, :3] = R |
|
|
Rt[:, :3, 3] = t |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
np_per_view = pts3d_gs.shape[0] // n_inference |
|
|
rgb_render_relocated = [ |
|
|
self.gs_renderer(torch.linalg.inv(c2ws_gs[i:i+1]), intrinsics[i:i+1], pts3d_gs[i * np_per_view: (i + 1) * np_per_view], rgb_gs[i * np_per_view: (i + 1) * np_per_view], opacity_gs[i * np_per_view: (i + 1) * np_per_view], scale_gs[i * np_per_view: (i + 1) * np_per_view], rot_gs[i * np_per_view: (i + 1) * np_per_view], eps2d=0.1, SH = SH)['rgb'] |
|
|
for i in range(n_inference) |
|
|
] |
|
|
|
|
|
rgb_render_relocated = torch.cat(rgb_render_relocated, 0) |
|
|
render_relocated_all.append(rgb_render_relocated.detach().cpu()) |
|
|
|
|
|
res = self.gs_renderer(torch.linalg.inv(c2ws_gs), intrinsics, pts3d_gs, rgb_gs, opacity_gs, scale_gs, rot_gs, eps2d=0.1, SH = SH) |
|
|
|
|
|
rgb_render = res['rgb'] |
|
|
else: |
|
|
rgb_render = torch.zeros_like(gt_imgs) |
|
|
photometric_results = [calc_metrics.calc_metrics(gt_imgs[i].permute(2,0,1), rgb_render[i].permute(2,0,1), log) for i in range(nv)] |
|
|
photometrics_all.append(combine_dict(photometric_results)) |
|
|
photometrics_inference_all.append(combine_dict(photometric_results[:-num_render_views] if num_render_views else photometric_results)) |
|
|
            # note: when num_render_views == 0 this keeps all views ([-0:] selects everything)
            photometrics_render_all.append(combine_dict(photometric_results[-num_render_views:] if num_render_views else photometric_results))
|
|
|
|
|
log_scale = log_scale + scale_gs.mean() |
|
|
|
|
|
ls = self.criterion(rgb_render[mask_all[dp_id]], gt_imgs[mask_all[dp_id]]) |
|
|
if self.render_included: |
|
|
render_all.append(rgb_render.detach().cpu()) |
|
|
loss_rgb = loss_rgb + ls |
|
|
loss_scale = loss_scale + scale_clip_loss |
|
|
if self.lpips_coeff > 0: |
|
|
loss_lpips = loss_lpips + calc_metrics.calc_lpips(gt_imgs.permute(0,3,1,2), rgb_render.permute(0,3,1,2)) |
|
|
|
|
|
loss_rgb = loss_rgb / bs |
|
|
loss_scale = loss_scale / bs |
|
|
loss_lpips = loss_lpips / bs |
|
|
loss_copy_rgb = loss_copy_rgb / bs |
|
|
|
|
|
log_scale = log_scale / bs |
|
|
|
|
|
loss = self.rgb_coeff * loss_rgb + 1.0 * loss_scale + self.lpips_coeff * loss_lpips + self.copy_rgb_coeff * loss_copy_rgb + self.local_loss_coeff * loss_local + self.lap_loss_coeff * loss_lap |
|
|
|
|
|
self_name = type(self).__name__ |
|
|
details = {} |
|
|
if log: |
|
|
photometrics_all_ = combine_dict(photometrics_all) |
|
|
photometrics_inference_all_ = combine_dict(photometrics_inference_all) |
|
|
photometrics_render_all_ = combine_dict(photometrics_render_all) |
|
|
photometrics_all_list = combine_dict(photometrics_all, make_list=True) |
|
|
photometrics_inference_all_list = combine_dict(photometrics_inference_all, make_list=True) |
|
|
photometrics_render_all_list = combine_dict(photometrics_render_all, make_list=True) |
|
|
details[self_name+'_gs_rgb'] = float(loss_rgb) |
|
|
details[self_name+'_gs_copy_rgb'] = float(loss_copy_rgb) |
|
|
details[self_name+'_gs_scale_clip'] = float(loss_scale) |
|
|
details[self_name+'_gs_scale'] = float(log_scale) |
|
|
details[self_name+'_gs_loss_all'] = float(loss) |
|
|
details[self_name+'_local_loss'] = float(loss_local) |
|
|
details[self_name+'_lap_loss'] = float(loss_lap) |
|
|
details[self_name+'_RRA'] = extra_info['RRA'].mean().item() |
|
|
details[self_name+'_RTA'] = extra_info['RTA'].mean().item() |
|
|
details[self_name+'_mAA'] = extra_info['mAA'].mean().item() |
|
|
details[self_name+'_RRA_list'] = extra_info['RRA'].detach().cpu().tolist() |
|
|
details[self_name+'_RTA_list'] = extra_info['RTA'].detach().cpu().tolist() |
|
|
details[self_name+'_mAA_list'] = extra_info['mAA'].detach().cpu().tolist() |
|
|
|
|
|
for k in photometrics_all_.keys(): |
|
|
details[self_name+f'_gs_all_{k}'] = float(photometrics_all_[k]) |
|
|
details[self_name+f'_gs_inference_{k}'] = float(photometrics_inference_all_[k]) |
|
|
details[self_name+f'_gs_render_{k}'] = float(photometrics_render_all_[k]) |
|
|
details[self_name+f'_gs_all_{k}_list'] = photometrics_all_list[k] |
|
|
details[self_name+f'_gs_inference_{k}_list'] = photometrics_inference_all_list[k] |
|
|
details[self_name+f'_gs_render_{k}_list'] = photometrics_render_all_list[k] |
|
|
if self.render_included: |
|
|
details['render_all'] = render_all |
|
|
if self.cam_relocation: |
|
|
details['render_relocated_all'] = render_relocated_all |
|
|
return loss, (details | monitoring) |
|
|
|