from copy import deepcopy
import os

import torch
import torch.nn as nn
import huggingface_hub
from packaging import version

from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape
from .heads import head_factory
from dust3r.patch_embed import get_patch_embed
from dust3r.losses import swap, swap_ref

import dust3r.utils.path_to_croco  # noqa: F401 -- adds the CroCo submodule to sys.path
from models.croco import CroCoNet

inf = float('inf')

hf_version_number = huggingface_hub.__version__
assert version.parse(hf_version_number) >= version.parse("0.22.0"), \
    "Outdated huggingface_hub version, please reinstall requirements.txt"


def load_model(model_path, device, verbose=True):
    if verbose:
        print('... loading model from', model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if 'landscape_only' not in args:
        args = args[:-1] + ', landscape_only=False)'
    else:
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    net = eval(args)  # the checkpoint stores the model constructor call as a string
    s = net.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(s)
    return net.to(device)
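# A minimal usage sketch (the checkpoint path below is hypothetical):
#
#   model = load_model('checkpoints/dust3r_demo.pth', device='cuda')
#   model.eval()
#
# Note that the constructor string is rewritten to force landscape_only=False,
# so the loaded model accepts both landscape and portrait inputs.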


class AsymmetricCroCo3DStereo(
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """ Two siamese encoders, followed by two decoders.
    The goal is to output 3D points directly for both images, expressed in view1's frame
    (hence the asymmetry).
    """

    def __init__(self,
                 output_mode='pts3d',
                 head_type='linear',
                 depth_mode=('exp', -inf, inf),
                 conf_mode=('exp', 1, inf),
                 freeze='none',
                 landscape_only=True,
                 patch_embed_cls='PatchEmbedDust3R',
                 **croco_kwargs):
        self.patch_embed_cls = patch_embed_cls
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r-specific initialization: an independent second copy of the decoder
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        else:
            return super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        # duplicate all decoder weights for the second decoder if they are not present yet
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)
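    # Illustration of the duplication above: a checkpoint trained with a single decoder
    # still loads, because every 'dec_blocks.*' weight is mirrored into 'dec_blocks2.*',
    # e.g. 'dec_blocks.0.<param>' -> 'dec_blocks2.0.<param>'.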

    def set_freeze(self, freeze):
        self.freeze = freeze
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
                            **kw):
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        # allocate heads
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))

        # magic wrapper: the heads always see landscape-oriented features
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)

    def _encode_image(self, image, true_shape):
        # embed the image into patches (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # positions are encoded by RoPE inside the attention blocks, not added here
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        if img1.shape[-2:] == img2.shape[-2:]:
            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
                                             torch.cat((true_shape1, true_shape2), dim=0))
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]

        # warning! maybe the images have different portrait/landscape orientations
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))

        if is_symmetrized(view1, view2):
            # the batch contains every pair twice, as (I, J) and (J, I):
            # compute only half of the forward pass and interleave the results
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        final_output = [(f1, f2)]  # before projection

        # project to decoder dimension
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side: attends to img2 tokens
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # img2 side: same pattern with the argument order reversed
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # store the result
            final_output.append((f1, f2))

        # remove the duplicate post-projection entry and normalize the last output
        del final_output[1]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)
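    # Shape contract (illustrative): _decoder returns two sequences, one per image side;
    # each yields len(self.dec_blocks) + 1 tensors -- the pre-projection encoder features
    # followed by every decoder block's output, the last one normalized by dec_norm.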

    def _downstream_head(self, head_num, decout, img_shape):
        B, S, D = decout[-1].shape
        head = getattr(self, f'head{head_num}')
        return head(decout, img_shape)

    def forward(self, view1, view2):
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        with torch.cuda.amp.autocast(enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predict view2's pts3d in view1's frame
        return res1, res2
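    # A minimal inference sketch (shapes are illustrative; img tensors are assumed to be
    # normalized (B, 3, H, W) batches):
    #
    #   view1 = {'img': img1, 'true_shape': torch.tensor([[H, W]] * B)}
    #   view2 = {'img': img2, 'true_shape': torch.tensor([[H, W]] * B)}
    #   with torch.no_grad():
    #       res1, res2 = model(view1, view2)
    #   # res1['pts3d']               -> view1's 3D points in view1's frame
    #   # res2['pts3d_in_other_view'] -> view2's 3D points, also in view1's frame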


def except_i(a, i):
    """ Return `a` without its i-th element (works for lists and batched tensors). """
    if i == 0:
        return a[1:]
    elif i == len(a) - 1:
        return a[:-1]
    if type(a) == list:
        return a[:i] + a[i + 1:]
    return torch.cat([a[:i], a[i + 1:]], dim=0)
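# Examples of except_i:
#   except_i([10, 11, 12], 0)    -> [11, 12]
#   except_i([10, 11, 12], 1)    -> [10, 12]
#   except_i(torch.arange(3), 1) -> tensor([0, 2])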


class AsymmetricCroCo3DStereoMultiView(
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """ Multi-view variant: one shared siamese encoder, followed by two decoders.
    The goal is to output 3D points directly for every view, expressed in the
    reference view's frame (hence the asymmetry).
    """

    def __init__(self,
                 output_mode='pts3d',
                 head_type='linear',
                 depth_mode=('exp', -inf, inf),
                 conf_mode=('exp', 1, inf),
                 freeze='none',
                 landscape_only=True,
                 patch_embed_cls='PatchEmbedDust3R',
                 GS=False,
                 GS_skip=False,
                 sh_degree=0,
                 pts_head_config=None,
                 n_ref=None,
                 **croco_kwargs):
        self.patch_embed_cls = patch_embed_cls
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # avoid a shared mutable default for the head config
        self.pts_head_config = pts_head_config if pts_head_config is not None else {}
        # an independent second copy of the decoder, as in the two-view model
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.GS = GS
        self.GS_skip = GS_skip
        self.sh_degree = sh_degree
        self.n_ref = n_ref
        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        else:
            # note: this must reference the multi-view class, not AsymmetricCroCo3DStereo
            return super(AsymmetricCroCo3DStereoMultiView, cls).from_pretrained(pretrained_model_name_or_path, **kw)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        # duplicate all decoder weights for the second decoder if they are not present yet
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):
        self.freeze = freeze
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
                            **kw):
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        # allocate heads
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode),
                                             pts_head_config=self.pts_head_config)
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode),
                                             pts_head_config=self.pts_head_config)

        # magic wrapper: the heads always see landscape-oriented features
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
        self.GS_head = [None, None]
        if self.GS:
            # optional Gaussian-splatting heads, wrapped the same way as the point heads
            self.downstream_GS_head = nn.ModuleList(
                [head_factory("GSHead", net=self, skip=self.GS_skip, sh_degree=self.sh_degree) for i in range(2)])
            self.GS_head = [transpose_to_landscape(self.downstream_GS_head[i], activate=landscape_only)
                            for i in range(2)]

    def _encode_image(self, image, true_shape):
        # embed the image into patches (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # positions are encoded by RoPE inside the attention blocks, not added here
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2s, true_shape1, true_shape2s):
        if img1.shape[-2:] == img2s[0].shape[-2:]:
            # all views share one resolution: encode them as a single batch and split
            n_view = 1 + len(img2s)
            out, pos, _ = self._encode_image(torch.cat((img1, *img2s), dim=0),
                                             torch.cat((true_shape1, *true_shape2s), dim=0))
            outs = out.chunk(n_view, dim=0)
            poss = pos.chunk(n_view, dim=0)
            out, out2s = outs[0], outs[1:]
            pos, pos2s = poss[0], poss[1:]
        else:
            raise NotImplementedError
        return out, out2s, pos, pos2s

    def _encode_symmetrized(self, view1, view2s):
        img1 = view1['img']
        img2s = [view2['img'] for view2 in view2s]
        B = img1.shape[0]

        # warning! maybe the images have different portrait/landscape orientations
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2s = [view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
                   for (view2, img2) in zip(view2s, img2s)]

        feat1, feat2s, pos1, pos2s = self._encode_image_pairs(img1, img2s, shape1, shape2s)

        return (shape1, shape2s), (feat1, feat2s), (pos1, pos2s)

    def _decoder(self, f1, pos1, f2s, pos2s, n_ref=1):
        if n_ref > 1:
            return self._decoder_multi_ref(f1, pos1, f2s, pos2s, n_ref)
        final_output = [(f1, *f2s)]  # before projection
        n_view_src = len(f2s)

        # project to decoder dimension
        f1 = self.decoder_embed(f1)
        bs = f1.shape[0]
        f2s = torch.cat(f2s, dim=0)
        f2s = self.decoder_embed(f2s).split(bs)

        final_output.append((f1, *f2s))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # reference view cross-attends to all source views
            f1, _ = blk1(final_output[-1][0], final_output[-1][1:], pos1, pos2s, mv=True)
            # each source view cross-attends to every other view (reference included)
            f2s = []
            for i in range(n_view_src):
                f2s_other = list(final_output[-1][:1 + i]) + list(final_output[-1][1 + i + 1:])
                pos2s_other = [pos1] + list(pos2s[:i]) + list(pos2s[i + 1:])
                f2s.append(blk2(final_output[-1][1 + i], f2s_other, pos2s[i], pos2s_other, mv=True)[0])

            final_output.append((f1, *f2s))

        # remove the duplicate post-projection entry and normalize the last output
        del final_output[1]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))

        # re-pack: one list of per-depth features for the reference view, one per source view
        f1_all = []
        f2_alls = [[] for i in range(n_view_src)]
        for i in range(len(final_output)):
            f1_all.append(final_output[i][0])
            for j in range(n_view_src):
                f2_alls[j].append(final_output[i][1 + j])
        return f1_all, f2_alls

    def _decoder_multi_ref(self, f1, pos1, f2s, pos2s, n_ref=1):
        # one full copy of the token set per candidate reference view
        final_output_mref = [[[f1, *f2s] for i in range(n_ref)]]
        n_view_src = len(f2s)
        nv = n_view_src + 1

        # project to decoder dimension
        f1 = self.decoder_embed(f1)
        bs = f1.shape[0]
        f2s = torch.cat(f2s, dim=0)
        f2s = self.decoder_embed(f2s).split(bs)
        pos_all = [pos1] + list(pos2s)
        final_output_mref.append([[f1, *f2s] for i in range(n_ref)])
        # dec_same_view_blocks is assumed to be defined elsewhere (it is not created in
        # __init__ here); entries may be None, in which case same-view attention is skipped
        for blk1, blk2, blk_sv in zip(self.dec_blocks, self.dec_blocks2, self.dec_same_view_blocks):
            final_output_mref_i = []
            for ref_id in range(n_ref):
                # cross-view attention within the branch of this reference view
                fs = [None for i in range(nv)]
                f1, _ = blk1(final_output_mref[-1][ref_id][ref_id],
                             except_i(final_output_mref[-1][ref_id], ref_id), pos1, pos2s, mv=True)
                fs[ref_id] = f1

                for other_view_id in range(nv):
                    if other_view_id == ref_id:
                        continue
                    f2 = blk2(final_output_mref[-1][ref_id][other_view_id],
                              except_i(final_output_mref[-1][ref_id], other_view_id), pos1, pos2s, mv=True)[0]
                    fs[other_view_id] = f2

                final_output_mref_i.append(fs)
            # same-view attention across the reference branches
            fs_new = [[None for i in range(nv)] for j in range(n_ref)]
            for view_id in range(nv):
                final_output_mref_i_view = [final_output_mref_i[i][view_id] for i in range(n_ref)]
                for ref_id in range(n_ref):
                    if blk_sv is not None:
                        fs_new[ref_id][view_id] = blk_sv(final_output_mref_i[ref_id][view_id],
                                                         except_i(final_output_mref_i_view, ref_id),
                                                         pos1, pos2s[:n_ref - 1], mv=True, coeff=1.)[0]
                    else:
                        fs_new[ref_id][view_id] = final_output_mref_i[ref_id][view_id]

            final_output_mref.append(fs_new)

        # remove the duplicate post-projection entry
        del final_output_mref[1]
        final_output = final_output_mref

        # merge the n_ref branches into the batch dimension: (B, n_ref, ...) -> (B * n_ref, ...)
        final_output_last = []
        for view_id in range(nv):
            final_output_last_view = torch.stack([final_output[-1][i][view_id] for i in range(n_ref)], dim=1)
            final_output_last.append(final_output_last_view.flatten(0, 1))
        final_output[-1] = tuple(final_output_last)

        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))

        # in branch ref_id, view ref_id plays the role of the reference: swap it into slot 0
        for data_id in range(bs):
            for ref_id in range(1, n_ref):
                swap_ref(final_output[-1][0][data_id * n_ref + ref_id],
                         final_output[-1][ref_id][data_id * n_ref + ref_id])

        # only the final (normalized) depth is kept for the heads
        final_output = final_output[-1:]
        f1_all = []
        f2_alls = [[] for i in range(n_view_src)]
        for i in range(len(final_output)):
            f1_all.append(final_output[i][0])
            for j in range(n_view_src):
                f2_alls[j].append(final_output[i][1 + j])

        return f1_all, f2_alls

    def _downstream_head(self, head_num, decout, img_shape):
        head = getattr(self, f'head{head_num}')
        return head(decout, img_shape)

    def _downstream_head_GS(self, head_num, decout, img_shape):
        head = self.GS_head[head_num - 1]
        return head(decout, img_shape)

    def forward(self, view1, view2s_all):
        # the trailing views may be render-only targets that skip the network
        num_render_views = view2s_all[0].get("num_render_views", torch.Tensor([0]).long())[0].item()
        n_ref = view2s_all[0].get("n_ref", torch.Tensor([1]).long())[0].item()
        if self.n_ref is not None:
            n_ref = self.n_ref
        # m_ref_flag is assumed to be set externally (it is not initialized in __init__)
        assert not self.m_ref_flag or n_ref > 1, "number of reference views must be > 1 when m_ref_flag is True"

        if num_render_views:
            view2s, view2s_render = view2s_all[:-num_render_views], view2s_all[-num_render_views:]
        else:
            view2s, view2s_render = view2s_all, []

        # encode all images --> B,S,D
        (shape1, shape2s), (feat1, feat2s), (pos1, pos2s) = self._encode_symmetrized(view1, view2s)

        # decode with cross-view attention
        dec1, dec2s = self._decoder(feat1, pos1, feat2s, pos2s, n_ref=n_ref)

        with torch.cuda.amp.autocast(enabled=False):
            # replicate the images along the batch dim, once per reference branch,
            # and apply the same reference swap as in the decoder
            bs = view1['img'].shape[0]
            view1_img = view1['img'].repeat_interleave(n_ref, dim=0)
            view2s_img = [view2['img'].repeat_interleave(n_ref, dim=0) for view2 in view2s]

            views_img = [view1_img] + view2s_img
            for data_id in range(bs):
                for ref_id in range(1, n_ref):
                    swap_ref(views_img[0][data_id * n_ref + ref_id], views_img[ref_id][data_id * n_ref + ref_id])
            view1_img = views_img[0]
            view2s_img = views_img[1:]

            res1 = self._downstream_head(1, ([tok.float() for tok in dec1], view1_img), shape1)
            res2s = [self._downstream_head(2, ([tok.float() for tok in dec2], view2_img), shape2)
                     for (dec2, shape2, view2_img) in zip(dec2s, shape2s, view2s_img)]
            if self.GS:
                res1_GS = self._downstream_head_GS(1, ([tok.float() for tok in dec1], view1_img), shape1)
                res2s_GS = [self._downstream_head_GS(2, ([tok.float() for tok in dec2], view2_img), shape2)
                            for (dec2, shape2, view2_img) in zip(dec2s, shape2s, view2s_img)]
                res1 = {**res1, **res1_GS}
                res2s_new = []
                for (res2, res2_GS) in zip(res2s, res2s_GS):
                    res2 = {**res2, **res2_GS}
                    res2s_new.append(res2)
                res2s = res2s_new

        for res2 in res2s:
            res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predicted in the reference view's frame

        return res1, res2s
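    # A minimal multi-view inference sketch (hypothetical tensors; every view dict needs
    # at least an 'img' entry, and 'num_render_views' / 'n_ref' are optional):
    #
    #   view1 = {'img': imgs[0]}
    #   view2s = [{'img': im} for im in imgs[1:]]
    #   res1, res2s = model(view1, view2s)
    #   # res1['pts3d'] and each res2['pts3d_in_other_view'] share the reference frame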