# HACO / demo.py
# dqj5182's picture
# init
# 5732928
import os
import cv2
import torch
import argparse
import numpy as np
from tqdm import tqdm
import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python import BaseOptions
from lib.core.config import cfg, update_config
from lib.models.model import HACO
from lib.utils.human_models import mano
from lib.utils.contact_utils import get_contact_thres
from lib.utils.vis_utils import ContactRenderer, draw_landmarks_on_image
from lib.utils.preprocessing import augmentation_contact
from lib.utils.demo_utils import remove_small_contact_components
# Command-line interface of the demo: which backbone to build, which
# checkpoint file to load (optional), and where the input images live.
parser = argparse.ArgumentParser(description='Demo HACO')
parser.add_argument(
    '--backbone',
    type=str,
    default='hamer',
    choices=[
        'hamer',
        'vit-l-16', 'vit-b-16', 'vit-s-16',
        'handoccnet',
        'hrnet-w48', 'hrnet-w32',
        'resnet-152', 'resnet-101', 'resnet-50', 'resnet-34', 'resnet-18',
    ],
    help='backbone model',
)
parser.add_argument(
    '--checkpoint',
    type=str,
    default='',  # empty string means "no checkpoint"; tested for truthiness later
    help='model path for demo',
)
parser.add_argument(
    '--input_path',
    type=str,
    default='asset/example_images',
    help='image path for demo',
)
# Parsed once at import time; the rest of the script reads `args` directly.
args = parser.parse_args()
# Run on CUDA when torch reports a usable GPU, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Experiment directory name handed to the project config.
experiment_dir = 'experiments_demo_image'

# Load config for the selected backbone; `cfg` is read below and in the demo loop.
update_config(backbone_type=args.backbone, exp_dir=experiment_dir)

# Renderer used in the demo loop to draw the predicted contact.
contact_renderer = ContactRenderer()

# Collect demo images, matched by file extension (case-insensitive).
# os.listdir() returns names in arbitrary order, so the listing is sorted to
# make the processing order and the per-frame log indices reproducible.
input_dir = args.input_path
images = sorted(
    f for f in os.listdir(input_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

# Initialize the MediaPipe HandLandmarker from the model asset named in the
# config; `num_hands=2` is passed through to the landmarker options.
base_options = BaseOptions(model_asset_path=cfg.MODEL.hand_landmarker_path)
hand_options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
detector = vision.HandLandmarker.create_from_options(hand_options)
############# Model #############
# Build the network on the selected device and switch it to evaluation mode,
# since the demo only runs inference.
model = HACO().to(device)
model.eval()
############# Model #############

# Load model checkpoint if provided
if args.checkpoint:
    # NOTE(security): torch.load relies on pickle-based deserialization, which
    # can execute arbitrary code from a malicious file. Only point
    # --checkpoint at files from a source you trust.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    # The checkpoint is expected to be a mapping that stores the weights under
    # the 'state_dict' key.
    model.load_state_dict(checkpoint['state_dict'])
else:
    # Make the fallback visible instead of silently producing predictions from
    # whatever weights HACO() was constructed with.
    print("Warning: no --checkpoint given; running with the weights HACO() was constructed with.")
############################### Demo Loop ###############################
def _build_input_transform(backbone):
    """Return a function that converts a cropped image array into the model input tensor.

    The backbone family is resolved once here instead of once per frame, so an
    unsupported backbone fails before any image is processed.

    Args:
        backbone: Backbone name from the command line (``args.backbone``).

    Returns:
        A callable taking the cropped image array and returning a tensor
        without the batch dimension.

    Raises:
        NotImplementedError: If neither the CNN-style nor the ViT-style branch
            matches ``backbone`` / ``cfg.MODEL.backbone_type``.
    """
    if backbone in ['handoccnet'] or 'resnet' in cfg.MODEL.backbone_type or 'hrnet' in cfg.MODEL.backbone_type:
        from torchvision import transforms
        to_tensor = transforms.ToTensor()

        def _plain(crop):
            # Scale to [0, 1] by hand; the array is float32 when handed to ToTensor.
            return to_tensor(crop.astype(np.float32) / 255.0)

        return _plain

    if backbone in ['hamer'] or 'vit' in cfg.MODEL.backbone_type:
        from torchvision.transforms import Normalize
        normalize = Normalize(mean=cfg.MODEL.img_mean, std=cfg.MODEL.img_std)

        def _normalized(crop):
            # Move the last axis (channels) to the front, scale to [0, 1], then
            # normalize with the configured mean / std.
            chw = crop.transpose(2, 0, 1) / 255.0
            return normalize(torch.from_numpy(chw)).float()

        return _normalized

    raise NotImplementedError(f"Unsupported backbone: {backbone}")


# Loop-invariant setup: input conversion, contact threshold and output folders
# do not depend on the individual frame, so they are prepared once.
to_model_input = _build_input_transform(args.backbone)
eval_thres = get_contact_thres(args.backbone)
os.makedirs('outputs', exist_ok=True)
os.makedirs('outputs/detection', exist_ok=True)
os.makedirs('outputs/crop_img', exist_ok=True)
os.makedirs('outputs/contact', exist_ok=True)

for i, frame_name in tqdm(enumerate(images), total=len(images)):
    print(f"Processing: {frame_name}")

    # Load and convert image
    frame_path = os.path.join(input_dir, frame_name)
    frame = cv2.imread(frame_path)
    if frame is None:
        # cv2.imread reports an unreadable or corrupt file by returning None
        # rather than raising; skip it instead of crashing in cvtColor.
        print(f"Skipping {frame_name} - could not read image.")
        continue
    orig_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_name_base = os.path.splitext(frame_name)[0]

    # Hand landmark detection
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=orig_img.copy())
    detection_result = detector.detect(mp_image)
    annotated_image, right_hand_bbox = draw_landmarks_on_image(orig_img.copy(), detection_result)

    if right_hand_bbox is None:
        print(f"Skipping {frame_name} - no hand detected.")
        continue

    print(f"Frame {i}: Right hand bbox: {right_hand_bbox}")

    # Image preprocessing: only the cropped image is used below; the other
    # returned values are unpacked but not needed by the demo.
    crop_img, img2bb_trans, bb2img_trans, rot, do_flip, color_scale = augmentation_contact(
        orig_img.copy(), right_hand_bbox, 'test', enforce_flip=False
    )

    # Convert to model input format
    img_tensor = to_model_input(crop_img)

    ############# Run model #############
    with torch.no_grad():
        # `[None]` adds the leading batch dimension.
        outputs = model({'input': {'image': img_tensor[None].to(device)}}, mode="test")
    ############# Run model #############

    # Save result (images are RGB here; cv2.imwrite expects BGR, hence the
    # colour conversion / channel reversal before writing).
    cv2.imwrite(f'outputs/detection/{frame_name_base}.png', cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR))
    cv2.imwrite(f'outputs/crop_img/{frame_name_base}.png', crop_img[..., ::-1])

    # Threshold the contact output of the single batch element, drop small
    # connected components on the right-hand mesh faces, then render it.
    contact_mask = (outputs['contact_out'][0] > eval_thres).detach().cpu().numpy()
    contact_mask = remove_small_contact_components(contact_mask, faces=mano.watertight_face['right'], min_size=20)
    contact_rendered = contact_renderer.render_contact(crop_img[..., ::-1], contact_mask)
    cv2.imwrite(f'outputs/contact/{frame_name_base}.png', contact_rendered)
############################### Demo Loop ###############################