|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import imageio |
|
|
import os |
|
|
|
|
|
if 'META_INTERNAL' in os.environ.keys() and os.environ['META_INTERNAL'] == "False": |
|
|
from dust3r.dummy_io import * |
|
|
else: |
|
|
from meta_internal.io import * |
|
|
|
|
|
import sys |
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__))) |
|
|
|
|
|
from dust3r.pcd_render import spiral_cam_gen, save_video_combined |
|
|
|
|
|
from gsplat.rendering import rasterization |
|
|
from gsplat.rendering import spherical_harmonics |
|
|
|
|
|
class GaussianRenderer(nn.Module):
    """Module wrapper around gsplat's ``rasterization`` for Gaussian splats.

    The instance stores the output resolution (``im_height``, ``im_width``),
    the near/far clipping distances (``znear``, ``zfar``) and an all-ones
    ``(1, 3)`` background colour registered as the buffer ``bg_color``.
    """

    def __init__(self, im_height = 224, im_width = 224, znear=0.01, zfar=100.0):
        """Store the view settings and register the background colour buffer.

        Args:
            im_height: Output image height; stored as ``int``.
            im_width: Output image width; stored as ``int``.
            znear: Near clipping distance.
            zfar: Far clipping distance.
        """
        super().__init__()
        self.set_view_info(height=im_height, width=im_width, znear=znear, zfar=zfar)

        # Background colour handed to the rasterizer; repeated once per camera in forward().
        self.register_buffer("bg_color", torch.ones((1, 3), dtype=torch.float32))

    def set_view_info(self, height=0, width=0, znear=0.01, zfar=100.0):
        """Overwrite the stored resolution and clipping distances.

        Args:
            height: New image height; stored as ``int``.
            width: New image width; stored as ``int``.
            znear: New near clipping distance.
            zfar: New far clipping distance.
        """
        self.im_height, self.im_width = int(height), int(width)
        self.znear, self.zfar = znear, zfar

    def compute_proj(self, tanfovx, tanfovy):
        """Build a 4x4 perspective projection matrix from half-FoV tangents.

        The frustum is symmetric about the optical axis and uses the stored
        ``znear`` / ``zfar``.

        Args:
            tanfovx: Tangent of half the horizontal field of view.
            tanfovy: Tangent of half the vertical field of view.

        Returns:
            A ``(4, 4)`` tensor created with ``torch.zeros`` (default dtype
            and device).
        """
        near, far = self.znear, self.zfar

        # Frustum extents on the near plane.
        y_max = tanfovy * near
        y_min = -y_max
        x_max = tanfovx * near
        x_min = -x_max

        x_span = x_max - x_min
        y_span = y_max - y_min
        depth_span = far - near
        z_sign = 1.0

        proj = torch.zeros(4, 4)
        proj[0, 0] = 2.0 * near / x_span
        proj[1, 1] = 2.0 * near / y_span
        proj[0, 2] = (x_max + x_min) / x_span
        proj[1, 2] = (y_max + y_min) / y_span
        proj[3, 2] = z_sign
        proj[2, 2] = z_sign * far / depth_span
        proj[2, 3] = -(far * near) / depth_span
        return proj

    def compute_intrinsic(self, tanfovx, tanfovy):
        """Build a 3x3 pinhole intrinsic matrix for the stored resolution.

        Focal lengths are ``size / tanfov / 2`` and the principal point is the
        image centre.

        Args:
            tanfovx: Tangent of half the horizontal field of view.
            tanfovy: Tangent of half the vertical field of view.

        Returns:
            A ``(3, 3)`` tensor created with ``torch.eye`` (default dtype and
            device).
        """
        width, height = self.im_width, self.im_height

        intrinsic = torch.eye(3)
        intrinsic[0, 0] = width / tanfovx / 2.0
        intrinsic[1, 1] = height / tanfovy / 2.0
        intrinsic[0, 2] = width / 2.0
        intrinsic[1, 2] = height / 2.0
        return intrinsic

    def calc_color_from_sh(self, pcds, c2ws, sh, sh_degree, debug = False):
        """Evaluate spherical-harmonic colours for points seen from each camera.

        Args:
            pcds: Point positions; the camera centre is subtracted from them
                to obtain the directions (last axis presumably xyz — confirm).
            c2ws: Camera matrices whose ``[:, :3, 3]`` column is used as the
                camera centre.
            sh: Spherical-harmonic coefficients forwarded to gsplat.
            sh_degree: SH degree forwarded to gsplat.
            debug: Unused.

        Returns:
            The gsplat SH result shifted by ``+0.5`` and then mapped with
            ``x * 2 - 1`` (presumably [0, 1] -> [-1, 1] — confirm).
        """
        cam_centers = c2ws[:, :3, 3]
        # Directions from each camera centre to each point (broadcast over points).
        view_dirs = pcds - cam_centers[..., None, :]
        shaded = spherical_harmonics(sh_degree, view_dirs, sh) + 0.5
        return shaded * 2 - 1

    def forward(self, w2cs, Ks, xyz, rgb, opacity, scale, rotation, eps2d=0.3, SH = False, debug = False):
        """Rasterize the Gaussians for every camera in ``w2cs``.

        Args:
            w2cs: View matrices passed to gsplat as ``viewmats``; the first
                dimension is the number of cameras.
            Ks: Camera intrinsics passed through to gsplat.
            xyz: Gaussian centres (gsplat ``means``).
            rgb: Colours. If the last dimension is not 3, or ``SH`` is set,
                it is reshaped to ``(-1, n_bases, 3)``. A 2-D ``rgb`` is then
                rendered as plain colours; otherwise it is treated as SH
                coefficients of degree ``int(sqrt(rgb.shape[1])) - 1``.
            opacity: Gaussian opacities.
            scale: Gaussian scales.
            rotation: Gaussian rotations (gsplat ``quats``).
            eps2d: Forwarded to gsplat unchanged.
            SH: Force the SH reshape even when the last dimension is 3.
            debug: Unused.

        Returns:
            ``{"rgb": image, "mask": alpha}``. On the SH path the image is
            additionally mapped with ``x * 2 - 1``.
        """
        if rgb.shape[-1] != 3 or SH:
            # Flat SH coefficients: split the last axis into (bases, 3).
            n_bases = rgb.shape[-1] // 3
            rgb = rgb.reshape(-1, n_bases, 3)

        # 2-D colours are plain per-Gaussian values; 3-D colours carry SH bases.
        sh_degree = None if rgb.ndim == 2 else int(np.sqrt(rgb.shape[1])) - 1

        view_mats = w2cs
        num_cams = view_mats.shape[0]
        backgrounds = self.bg_color.repeat(num_cams, 1).to(view_mats.device)

        image, alpha, _ = rasterization(
            means=xyz,
            quats=rotation,
            scales=scale,
            opacities=opacity,
            colors=rgb,
            viewmats=view_mats,
            Ks=Ks,
            width=self.im_width,
            height=self.im_height,
            near_plane=self.znear,
            far_plane=self.zfar,
            backgrounds=backgrounds,
            sh_degree=sh_degree,
            eps2d=eps2d,
        )

        if sh_degree is not None:
            image = image * 2 - 1

        return {"rgb": image, "mask": alpha}
|
|
|
|
|
def gs_render(gts, preds, dp_id_gt, dp_id_pred, c2w_canonical, normalize = False, rot = True, gt_img = False, gt_pcd = False):
    """Render one sample's Gaussians along a spiral camera path.

    The point cloud is rescaled by its largest axis-aligned extent, centred
    on its mean and multiplied by 2 before rendering. Cameras come from
    ``spiral_cam_gen`` and all share the first view's intrinsics.

    Args:
        gts: Per-view ground-truth dicts. ``'camera_intrinsics'`` is always
            read; ``'pts3d'`` is read when ``gt_pcd`` and ``'img'`` when
            ``gt_img``. Each entry is indexed with ``dp_id_gt``.
        preds: Per-view prediction dicts. ``'rotation'``, ``'scale'`` and
            ``'opacity'`` are always read. ``'pts3d'`` (first view) and
            ``'pts3d_in_other_view'`` (remaining views) are read unless
            ``gt_pcd``; ``'rgb'`` is read unless ``gt_img``. Each entry is
            indexed with ``dp_id_pred``.
        dp_id_gt: Index into the ground-truth entries (presumably the sample
            index within the batch — confirm against the caller).
        dp_id_pred: Index into the prediction entries.
        c2w_canonical: 4x4 matrix. Its inverse, with the translation zeroed,
            is applied to the camera path when ``gt_pcd`` is False.
        normalize: Unused; kept for interface compatibility.
        rot: Unused; kept for interface compatibility.
        gt_img: Colour the Gaussians with ground-truth pixels and replace the
            predicted rotation / scale / opacity with constants.
        gt_pcd: Use ground-truth points instead of predicted ones.

    Returns:
        A list of ``uint8`` numpy arrays, one frame per generated camera.
    """
    # Stacked for every view, but only the first view's 3x3 intrinsics are used below.
    intrinsics = torch.stack([gt['camera_intrinsics'][dp_id_gt][:3, :3] for gt in gts]).cuda()

    rot_gs = torch.stack([pred['rotation'][dp_id_pred] for pred in preds], 0).reshape(-1, 4)
    scale_gs = torch.stack([pred['scale'][dp_id_pred] for pred in preds], 0).reshape(-1, 3)
    opacity_gs = torch.stack([pred['opacity'][dp_id_pred] for pred in preds], 0).reshape(-1)

    if gt_pcd:
        pts3d = torch.cat([gt['pts3d'][dp_id_gt] for gt in gts], 0).reshape(-1, 3).cuda()
    else:
        pts3d = torch.cat(
            [preds[0]['pts3d'][dp_id_pred]]
            + [pred['pts3d_in_other_view'][dp_id_pred] for pred in preds[1:]],
            0,
        ).reshape(-1, 3).cuda()

    if gt_img:
        # Channel-first images -> one colour row per pixel.
        imgs = torch.stack([gt['img'][dp_id_gt] for gt in gts], 0).permute(0, 2, 3, 1).reshape(-1, 3).cuda()
        # Constant Gaussian parameters stand in for the predicted ones.
        rot_gs = torch.ones_like(rot_gs)
        scale_gs = torch.ones_like(scale_gs) * 1e-3
        opacity_gs = torch.ones_like(opacity_gs) * 0.5
    else:
        imgs = torch.stack([pred['rgb'][dp_id_pred] for pred in preds], 0).flatten(0, -2).cuda()
        # NOTE(review): the SH reshape is applied to predicted colours only, so
        # ground-truth pixels stay (N, 3) and are rendered as plain RGB — confirm.
        sh_base = imgs.shape[-1] // 3
        imgs = imgs.reshape(-1, sh_base, 3)

    # Default renderer settings (224x224); presumably matches the input images — confirm.
    gs = GaussianRenderer()

    # Normalise the cloud: largest axis-aligned extent -> 1, centre on the mean, then scale up.
    pts3d_rate = 2.
    pcd_range = (torch.max(pts3d, dim=0)[0] - torch.min(pts3d, dim=0)[0]).max()
    pts3d = pts3d / pcd_range
    pts3d = pts3d - torch.mean(pts3d, dim=0)
    pts3d *= pts3d_rate
    if not gt_img:
        # Predicted scales must follow the same rescaling as the points.
        scale_gs = scale_gs / pcd_range * pts3d_rate
    scale_range = [0.0001, 0.004]
    # The constant 1e-3 used for gt_img already lies inside this range.
    scale_gs = torch.clamp(scale_gs, scale_range[0], scale_range[1])

    w2cs = spiral_cam_gen(imgs.device, 36)
    c2ws = torch.linalg.inv(w2cs)
    # Negate the first two columns of every camera-to-world matrix.
    c2ws[:, :, :2] *= -1

    # Fixed axis permutation (x, y, z) -> (x, -z, y), embedded in a 4x4 transform.
    R_ = torch.tensor([
        [1.0, 0.0, 0.0],
        [0.0, 0.0, -1.0],
        [0.0, 1.0, 0.0],
    ]).to(c2ws)
    R = torch.eye(4).to(R_)
    R[:3, :3] = R_

    c2ws2 = R @ c2ws
    if not gt_pcd:
        # Apply only the rotation part of the inverse canonical pose.
        c2w_canonical_inv = torch.linalg.inv(c2w_canonical)
        c2w_canonical_inv[:3, 3] = 0
        c2ws2 = c2w_canonical_inv @ c2ws2

    intrinsics = intrinsics[0].repeat(c2ws.shape[0], 1, 1)
    res = gs(torch.linalg.inv(c2ws2), intrinsics, pts3d, imgs, opacity_gs, scale_gs, rot_gs, eps2d=0.1)

    rgb = res['rgb']
    # Map [-1, 1] -> [0, 255]. Clamp first: the previous `* 256` sent a value of
    # exactly 1.0 (e.g. the all-ones background) to 256, which is out of range
    # for uint8, as is anything outside [-1, 1].
    frames = ((rgb.detach().cpu().float().clamp(-1.0, 1.0) + 1.0) * 127.5).round().to(torch.uint8).numpy()
    video_frames = list(frames)

    return video_frames
|
|
|