import traceback

import numpy as np
import torch
import torchvision.transforms as transforms
from diffusers import AutoencoderKL, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from mgface.pipelines_mgface.pipeline_mgface import MgPipeline
from mgface.pipelines_mgface.unet_ID_2d_condition import UNetID2DConditionModel
from mgface.pipelines_mgface.unet_deno_2d_condition import UNetDeno2DConditionModel


class MagicFaceModel:
    """Wrapper that loads the MagicFace components and runs AU-driven edits.

    Construction loads every component (VAE, text encoder, tokenizer, both
    UNets and the pipeline) immediately, so creating an instance triggers
    model downloads and device transfers.
    """

    def __init__(self, device='cuda'):
        """Select the device and dtype, then load all components.

        Args:
            device: Requested device. It is used only when CUDA is
                available; otherwise 'cpu' is used.
        """
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"🖥️ Initializing MagicFace on: {self.device}")

        # AU mapping: action-unit name -> index in the AU conditioning vector
        self.ind_dict = {
            'AU1': 0, 'AU2': 1, 'AU4': 2, 'AU5': 3, 'AU6': 4, 'AU9': 5,
            'AU12': 6, 'AU15': 7, 'AU17': 8, 'AU20': 9, 'AU25': 10, 'AU26': 11,
        }

        # Prefix match so indexed devices such as 'cuda:0' also get half
        # precision; an exact comparison with 'cuda' would miss them.
        on_cuda = str(self.device).startswith('cuda')
        self.weight_dtype = torch.float16 if on_cuda else torch.float32

        self.load_models()

    def load_models(self):
        """Load all MagicFace components and assemble the pipeline.

        Sets ``self.vae``, ``self.text_encoder``, ``self.tokenizer``,
        ``self.unet_ID``, ``self.unet_deno`` and ``self.pipeline``.

        Raises:
            Exception: Any error raised while loading is printed together
                with its traceback and then re-raised unchanged.
        """
        print("📥 Loading MagicFace components...")

        sd_model = 'runwayml/stable-diffusion-v1-5'

        try:
            # Load VAE from SD 1.5
            print(" - Loading VAE...")
            self.vae = AutoencoderKL.from_pretrained(
                sd_model,
                subfolder="vae",
                torch_dtype=self.weight_dtype,
            ).to(self.device)

            # Load Text Encoder from SD 1.5
            print(" - Loading Text Encoder...")
            self.text_encoder = CLIPTextModel.from_pretrained(
                sd_model,
                subfolder="text_encoder",
                torch_dtype=self.weight_dtype,
            ).to(self.device)

            # Load Tokenizer
            print(" - Loading Tokenizer...")
            self.tokenizer = CLIPTokenizer.from_pretrained(
                sd_model,
                subfolder="tokenizer",
            )

            # Download the fine-tuned MagicFace checkpoint
            print(" - Downloading YOUR MagicFace model (79999_iter.pth)...")
            magicface_path = hf_hub_download(
                repo_id="gauravvjhaa/magicface-affecto-model",
                filename="79999_iter.pth",
                cache_dir="./models",
            )
            print(f" - MagicFace model downloaded: {magicface_path}")

            # Load MagicFace weights.
            # The checkpoint is read onto the CPU: load_state_dict copies the
            # values into the already-placed parameters, so mapping it to the
            # target device would only hold a second copy in device memory.
            # NOTE(security): torch.load unpickles the file; only use
            # checkpoints from a trusted source.
            print(" - Loading MagicFace weights...")
            magicface_weights = torch.load(magicface_path, map_location='cpu')

            # Initialize UNets (architecture and default weights come from
            # the official MagicFace repository)
            print(" - Initializing ID UNet...")
            self.unet_ID = UNetID2DConditionModel.from_pretrained(
                'mengtingwei/magicface',  # Use official architecture
                subfolder='ID_enc',
                torch_dtype=self.weight_dtype,
                use_safetensors=True,
                low_cpu_mem_usage=True,
            ).to(self.device)

            print(" - Initializing Denoising UNet...")
            self.unet_deno = UNetDeno2DConditionModel.from_pretrained(
                'mengtingwei/magicface',  # Use official architecture
                subfolder='denoising_unet',
                torch_dtype=self.weight_dtype,
                use_safetensors=True,
                low_cpu_mem_usage=True,
            ).to(self.device)

            # Load the fine-tuned weights into the UNets (best effort)
            print(" - Loading YOUR trained weights...")
            self._load_custom_weights(magicface_weights)
            # Release the checkpoint before the pipeline is assembled.
            del magicface_weights

            # Inference only: freeze parameters and switch to eval mode
            for module in (self.vae, self.text_encoder, self.unet_ID, self.unet_deno):
                module.requires_grad_(False)
                module.eval()

            # Create pipeline
            print(" - Creating MagicFace pipeline...")
            self.pipeline = MgPipeline.from_pretrained(
                sd_model,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                unet_ID=self.unet_ID,
                unet_deno=self.unet_deno,
                safety_checker=None,
                torch_dtype=self.weight_dtype,
            ).to(self.device)

            # Set scheduler
            self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
                self.pipeline.scheduler.config
            )
            self.pipeline.set_progress_bar_config(disable=False)

            print("✅ MagicFace loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading MagicFace: {e}")
            traceback.print_exc()
            raise

    def _load_custom_weights(self, checkpoint) -> None:
        """Copy the fine-tuned checkpoint into the UNets, best effort.

        Recognized layouts, checked in this order:
          1. a dict with 'unet_ID' and/or 'unet_deno' entries, each holding
             the state dict of the matching UNet;
          2. a dict with a 'state_dict' entry, loaded into the denoising UNet;
          3. any other dict, treated as a raw state dict for the denoising
             UNet.

        A failure is reported on stdout and does not propagate, so loading
        continues with whatever weights the UNets hold at that point.

        Args:
            checkpoint: Object returned by ``torch.load`` for the checkpoint.
        """
        if not isinstance(checkpoint, dict):
            print(" ⚠️ Unexpected checkpoint format, using default weights")
            return

        # Deliberately broad: a bad checkpoint must not abort model loading.
        try:
            # Show a few keys to help diagnose the checkpoint layout
            print(f" - Checkpoint keys: {list(checkpoint)[:5]}...")

            if 'unet_ID' in checkpoint or 'unet_deno' in checkpoint:
                # Option 1: full model checkpoint. Each UNet is handled
                # independently so a checkpoint carrying only one of the two
                # entries never falls through to the raw-state-dict branch.
                if 'unet_ID' in checkpoint:
                    self.unet_ID.load_state_dict(checkpoint['unet_ID'])
                if 'unet_deno' in checkpoint:
                    self.unet_deno.load_state_dict(checkpoint['unet_deno'])
            elif 'state_dict' in checkpoint:
                # Option 2: a single wrapped UNet
                self.unet_deno.load_state_dict(checkpoint['state_dict'])
            else:
                # Option 3: the raw state dict
                self.unet_deno.load_state_dict(checkpoint)

            print(" ✅ YOUR weights loaded successfully!")
        except Exception as e:
            print(f" ⚠️ Could not load your weights: {e}")
            print(" ⚠️ Using default pretrained weights from mengtingwei/magicface")

    def tokenize_caption(self, caption: str) -> torch.Tensor:
        """Tokenize a text prompt.

        The prompt is padded and truncated to the tokenizer's
        ``model_max_length``.

        Args:
            caption: Prompt text.

        Returns:
            The token id tensor, moved to ``self.device``.
        """
        inputs = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return inputs.input_ids.to(self.device)

    def prepare_au_vector(self, au_params: dict) -> torch.Tensor:
        """Convert an AU parameter dict to a conditioning tensor.

        Args:
            au_params: Mapping of AU name (a key of ``self.ind_dict``) to a
                value convertible with ``float``. Names missing from
                ``self.ind_dict`` are reported and skipped.

        Returns:
            A float32 tensor of shape ``(1, len(self.ind_dict))`` on
            ``self.device``; entries for unspecified AUs are zero.
        """
        # Size follows the mapping so the two cannot drift apart.
        au_prompt = np.zeros((len(self.ind_dict),))
        unknown = []

        for au_name, value in au_params.items():
            index = self.ind_dict.get(au_name)
            if index is None:
                unknown.append(au_name)
                continue
            au_prompt[index] = float(value)

        if unknown:
            # Make typos visible instead of silently dropping them.
            print(f" ⚠️ Ignoring unknown AU names: {unknown}")

        print(f" 🎭 AU vector: {au_prompt}")
        return torch.from_numpy(au_prompt).float().unsqueeze(0).to(self.device)

    @torch.no_grad()
    def transform(self, source_image, bg_image, au_params,
                  num_inference_steps=50, seed=424):
        """Transform a facial expression using MagicFace.

        Args:
            source_image: PIL Image (512x512, cropped face)
            bg_image: PIL Image (512x512, background)
            au_params: dict like {"AU6": 2.0, "AU12": 2.0}
            num_inference_steps: number of diffusion steps
            seed: random seed

        Returns:
            PIL Image (transformed)

        Raises:
            Exception: Any error raised during inference is printed together
                with its traceback and then re-raised unchanged.
        """
        print("🎭 Starting MagicFace transformation...")
        print(f" AU params: {au_params}")
        print(f" Inference steps: {num_inference_steps}")

        try:
            # Prepare inputs: PIL image -> batched tensor on the target device
            to_tensor = transforms.ToTensor()
            source = to_tensor(source_image).unsqueeze(0).to(self.device)
            bg = to_tensor(bg_image).unsqueeze(0).to(self.device)

            # Get text embeddings for the fixed prompt
            prompt = "A close up of a person."
            prompt_ids = self.tokenize_caption(prompt)
            prompt_embeds = self.text_encoder(prompt_ids)[0]

            # Prepare AU vector
            au_vector = self.prepare_au_vector(au_params)

            # Seeded generator so the same seed is reused across calls
            generator = torch.Generator(device=self.device).manual_seed(seed)

            # Run inference
            print(" 🚀 Running diffusion pipeline...")
            result = self.pipeline(
                prompt_embeds=prompt_embeds,
                source=source,
                bg=bg,
                au=au_vector,
                num_inference_steps=num_inference_steps,
                generator=generator,
            )

            print("✅ Transformation complete!")
            return result.images[0]

        except Exception as e:
            print(f"❌ Transformation error: {e}")
            traceback.print_exc()
            raise