# mvdust3r / dust3r / gs.py
# Provenance: ArnoLiu's Hugging Face upload ("Upload folder using huggingface_hub", commit 83b6be6, verified)
# Copyright (C) 2025-present Meta Platforms, Inc. and affiliates. All rights reserved.
# Licensed under CC BY-NC 4.0 (non-commercial use only).
import torch
import torch.nn as nn
import numpy as np
import imageio
import os
if 'META_INTERNAL' in os.environ.keys() and os.environ['META_INTERNAL'] == "False":
from dust3r.dummy_io import *
else:
from meta_internal.io import *
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
from dust3r.pcd_render import spiral_cam_gen, save_video_combined
from gsplat.rendering import rasterization
from gsplat.rendering import spherical_harmonics
class GaussianRenderer(nn.Module):
    """Differentiable Gaussian-splat renderer built on gsplat's `rasterization`.

    Stores the target image size and near/far clip planes; `forward`
    rasterizes one shared set of 3D Gaussians for a batch of cameras and
    returns the rendered RGB image (in [-1, 1]) plus the alpha mask.
    """

    def __init__(self, im_height = 224, im_width = 224, znear=0.01, zfar=100.0):
        """Create a renderer for `im_width` x `im_height` images with the given clip planes."""
        super().__init__()
        self.im_height = int(im_height)
        self.im_width = int(im_width)
        self.znear = znear
        self.zfar = zfar
        # White background color; a buffer so it follows .to(device)/.cuda().
        self.register_buffer("bg_color", torch.ones((1, 3), dtype=torch.float32))

    def set_view_info(self, height=0, width=0, znear=0.01, zfar=100.0):
        """Update the render resolution and near/far clip planes in place."""
        self.im_height = int(height)
        self.im_width = int(width)
        self.znear = znear
        self.zfar = zfar

    def compute_proj(self, tanfovx, tanfovy):
        """Build a 4x4 perspective projection matrix (new CPU tensor).

        Args:
            tanfovx: tangent of half the horizontal field of view.
            tanfovy: tangent of half the vertical field of view.

        Returns:
            [4, 4] float tensor mapping camera space to clip space; with
            P[2,2] = f/(f-n) and P[2,3] = -fn/(f-n), depth z=n maps to 0 and
            z=f maps to 1 after the perspective divide.
        """
        top = tanfovy * self.znear
        bottom = -top
        right = tanfovx * self.znear
        left = -right
        P = torch.zeros(4, 4)
        z_sign = 1.0  # +z is the viewing direction
        P[0, 0] = 2.0 * self.znear / (right - left)
        P[1, 1] = 2.0 * self.znear / (top - bottom)
        P[0, 2] = (right + left) / (right - left)
        P[1, 2] = (top + bottom) / (top - bottom)
        P[3, 2] = z_sign
        P[2, 2] = z_sign * self.zfar / (self.zfar - self.znear)
        P[2, 3] = -(self.zfar * self.znear) / (self.zfar - self.znear)
        return P

    def compute_intrinsic(self, tanfovx, tanfovy):
        """Build a 3x3 pinhole intrinsic matrix from the half-FoV tangents.

        fx = W / (2 * tanfovx), fy = H / (2 * tanfovy); the principal point
        is placed at the image center.
        """
        Ks = torch.eye(3)
        Ks[0, 0] = self.im_width / tanfovx / 2.0
        Ks[1, 1] = self.im_height / tanfovy / 2.0
        Ks[0, 2] = self.im_width / 2.0
        Ks[1, 2] = self.im_height / 2.0
        return Ks

    def calc_color_from_sh(self, pcds, c2ws, sh, sh_degree, debug = False):
        """Evaluate spherical-harmonic coefficients into per-point colors.

        Args:
            pcds: [nc, N, 3] point positions.
            c2ws: [nc, 4, 4] camera-to-world matrices (translation column is
                taken as the camera center).
            sh: [nc, N, K, 3] SH coefficients.
            sh_degree: degree passed to gsplat's `spherical_harmonics`
                (signature: degrees_to_use, dirs [..., 3], coeffs [..., K, 3]).
            debug: unused.

        Returns:
            [nc, N, 3] colors mapped to [-1, 1].
        """
        # View direction from each camera center to each point. Passed
        # unnormalized — presumably gsplat normalizes internally; TODO confirm.
        dirs = pcds - c2ws[:, :3, 3][..., None, :]  # [nc, N, 3]
        colors = spherical_harmonics(sh_degree, dirs, sh)
        colors = colors + 0.5  # SH DC offset: shift roughly into [0, 1]
        # colors = torch.clamp_min(colors, 0)  # clamp deliberately disabled
        colors = colors * 2 - 1  # remap [0, 1] -> [-1, 1]
        return colors

    def forward(self, w2cs, Ks, xyz, rgb, opacity, scale, rotation, eps2d=0.3, SH = False, debug = False):
        """Rasterize Gaussians for a batch of cameras.

        Args:
            w2cs: [n_cam, 4, 4] world-to-camera matrices.
            Ks: [n_cam, 3, 3] camera intrinsics.
            xyz: [N, 3] Gaussian centers.
            rgb: per-Gaussian colors in [-1, 1] with last dim 3, or SH
                coefficients — flattened [..., K*3] (detected when the last
                dim != 3 or SH=True) or already shaped [N, K, 3].
            opacity: [N] opacities.
            scale: [N, 3] scales.
            rotation: [N, 4] quaternions.
            eps2d: 2D anti-aliasing epsilon forwarded to gsplat.
            SH: force treating `rgb` as flattened SH coefficients.
            debug: unused.

        Returns:
            dict with "rgb" ([n_cam, H, W, 3], in [-1, 1]) and "mask"
            ([n_cam, H, W, 1] alpha).
        """
        if rgb.shape[-1] != 3 or SH:
            # Flattened SH input: split the channel dim into K bases x 3.
            # (Renamed from `sh_degree`: this count is K, not the degree.)
            num_sh_bases = rgb.shape[-1] // 3
            rgb = rgb.reshape(-1, num_sh_bases, 3)
        if rgb.ndim == 2:
            # Plain colors in [-1, 1]; no SH evaluation inside gsplat.
            sh_degree = None
        else:
            # K = (degree + 1)^2 bases  =>  degree = sqrt(K) - 1.
            sh_base = rgb.shape[1]
            sh_degree = int(np.sqrt(sh_base)) - 1
        # Set up rasterization configuration
        extr = w2cs
        n_cam = extr.shape[0]
        bg_color = self.bg_color.repeat(n_cam, 1).to(extr.device)
        out_img, out_alpha, _ = rasterization(
            means=xyz,
            quats=rotation,
            scales=scale,
            opacities=opacity,
            colors=rgb,
            viewmats=extr,
            Ks=Ks,
            width=self.im_width,
            height=self.im_height,
            near_plane=self.znear,
            far_plane=self.zfar,
            backgrounds=bg_color,
            sh_degree=sh_degree,
            eps2d=eps2d,
        )
        if sh_degree is not None:
            # SH path renders in [0, 1]; remap to this module's [-1, 1] convention.
            out_img = out_img * 2 - 1  # e.g. [n_cam, 512, 512, 3]
        return {"rgb": out_img, "mask": out_alpha}  # mask: [n_cam, H, W, 1]
def gs_render(gts, preds, dp_id_gt, dp_id_pred, c2w_canonical, normalize = False, rot = True, gt_img = False, gt_pcd = False):
    """Render a 36-view spiral fly-through video of one datapoint's Gaussians.

    Gathers Gaussian parameters for batch element `dp_id_pred` from `preds`
    (optionally substituting GT geometry and/or colors), normalizes the point
    cloud to a canonical scale, generates spiral cameras, rasterizes them
    with `GaussianRenderer`, and returns uint8 frames.

    Args:
        gts: list of per-view GT dicts; reads 'camera_intrinsics', plus
            'pts3d' when gt_pcd and 'img' when gt_img.
        preds: list of per-view prediction dicts; reads 'rotation', 'scale',
            'opacity', 'rgb', and 'pts3d' / 'pts3d_in_other_view'.
        dp_id_gt: batch index into each GT dict.
        dp_id_pred: batch index into each prediction dict.
        c2w_canonical: [4, 4] canonical camera pose; its inverse rotation is
            applied to the spiral cameras for predicted (non-GT) clouds.
        normalize: unused — kept for interface compatibility.
        rot: unused — kept for interface compatibility.
        gt_img: if True, color with GT images and render the cloud as tiny
            fixed Gaussians (scale 1e-3, opacity 0.5) instead of predicted ones.
        gt_pcd: if True, use GT 3D points instead of predicted ones.

    Returns:
        list of 36 uint8 numpy frames, each [H, W, 3].
    """
    intrinsics = torch.stack([gt['camera_intrinsics'][dp_id_gt][:3,:3] for gt in gts]).cuda()  # [nv, 3, 3]
    rot_gs = torch.stack([pred['rotation'][dp_id_pred] for pred in preds], 0).reshape(-1, 4)
    scale_gs = torch.stack([pred['scale'][dp_id_pred] for pred in preds], 0).reshape(-1, 3)
    opacity_gs = torch.stack([pred['opacity'][dp_id_pred] for pred in preds], 0).reshape(-1)
    if gt_pcd:
        pts3d = torch.cat([gt['pts3d'][dp_id_gt] for gt in gts], 0).reshape(-1, 3).cuda()  # per view [224,224,3]
    else:
        # First view is predicted in its own frame; the remaining views are
        # predicted in the first view's frame ('pts3d_in_other_view').
        pts3d = torch.cat([preds[0]['pts3d'][dp_id_pred]] + [pred['pts3d_in_other_view'][dp_id_pred] for pred in preds[1:]], 0).reshape(-1, 3).cuda()  # [224,224,3]
    if gt_img:
        imgs = torch.stack([gt['img'][dp_id_gt] for gt in gts], 0).permute(0,2,3,1).reshape(-1, 3).cuda()  # single: [3,224,224]
        # Override predicted Gaussian params: near-point splats showing GT color.
        rot_gs = torch.ones_like(rot_gs)
        scale_gs = torch.ones_like(scale_gs) * 1e-3
        opacity_gs = torch.ones_like(opacity_gs) * 0.5
    else:
        imgs = torch.stack([pred['rgb'][dp_id_pred] for pred in preds], 0).flatten(0, -2).cuda()  # [nv, ]
        sh_base = imgs.shape[-1] // 3
        imgs = imgs.reshape(-1, sh_base, 3)  # [N, K, 3] SH coefficients
    gs = GaussianRenderer()
    # Normalize the cloud: divide by its largest axis-aligned extent,
    # re-center on the mean, then scale up by pts3d_rate.
    pts3d_rate = 2.
    pcd_range = (torch.max(pts3d, dim = 0)[0] - torch.min(pts3d, dim = 0)[0]).max()
    pts3d = pts3d / pcd_range
    pts3d = pts3d - torch.mean(pts3d, dim = 0)
    pts3d *= pts3d_rate
    if not gt_img:
        # Predicted scales live in the original metric space: apply the same
        # normalization, then clamp so no single splat dominates the render.
        scale_gs = scale_gs / pcd_range * pts3d_rate
        scale_range = [0.0001, 0.004]
        scale_gs = torch.clamp(scale_gs, scale_range[0], scale_range[1])
    w2cs = spiral_cam_gen(imgs.device, 36)
    c2ws = torch.linalg.inv(w2cs)
    # Negate the x/y basis columns of every pose — presumably a camera-axis
    # convention change (e.g. OpenGL vs OpenCV); verify against spiral_cam_gen.
    c2ws[:,:,:2] *= -1
    # 90-degree rotation about x (y -> z, z -> -y) applied to all spiral poses.
    R_ = torch.Tensor([
        [1, 0, 0],
        [0, 0, -1],
        [0, 1, 0],
    ]).to(c2ws)
    R = torch.eye(4).to(R_)
    R[:3,:3] = R_
    c2ws2 = R @ c2ws
    if not gt_pcd:
        # Undo the canonical camera rotation; translation is zeroed because
        # the cloud was already re-centered above.
        c2w_canonical_inv = torch.linalg.inv(c2w_canonical)
        c2w_canonical_inv[:3,3] = 0
        c2ws2 = c2w_canonical_inv @ c2ws2
    # All spiral views share the first view's intrinsics.
    intrinsics = intrinsics[0].repeat(c2ws.shape[0], 1, 1)
    res = gs(torch.linalg.inv(c2ws2), intrinsics, pts3d, imgs, opacity_gs, scale_gs, rot_gs, eps2d = 0.1)
    rgb = res['rgb']
    # Map [-1, 1] colors to uint8. NOTE(fix): the previous `* 256` overflowed
    # uint8 at exactly 1.0 (256 wraps to 0, turning white pixels black);
    # clamp to [0, 255] and scale by 255 instead.
    video_frames = [
        np.clip(((rgb[i].detach().cpu() + 1) / 2).float().numpy() * 255.0, 0, 255).astype(np.uint8)
        for i in range(len(rgb))
    ]
    return video_frames