import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from diffusers import AutoencoderKL, UniPCMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import hf_hub_download
from mgface.pipelines_mgface.pipeline_mgface import MgPipeline
from mgface.pipelines_mgface.unet_ID_2d_condition import UNetID2DConditionModel
from mgface.pipelines_mgface.unet_deno_2d_condition import UNetDeno2DConditionModel

class MagicFaceModel:
    def __init__(self, device='cuda'):
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"🖥️ Initializing MagicFace on: {self.device}")

        # Mapping from Action Unit name to its index in the 12-dim AU intensity vector
        self.ind_dict = {
            'AU1': 0, 'AU2': 1, 'AU4': 2, 'AU5': 3, 'AU6': 4, 'AU9': 5,
            'AU12': 6, 'AU15': 7, 'AU17': 8, 'AU20': 9, 'AU25': 10, 'AU26': 11
        }

        self.weight_dtype = torch.float16 if self.device == 'cuda' else torch.float32
        self.load_models()
    def load_models(self):
        """Load all MagicFace components."""
        print("📥 Loading MagicFace components...")
        sd_model = 'runwayml/stable-diffusion-v1-5'

        try:
            # Load the VAE from SD 1.5
            print(" - Loading VAE...")
            self.vae = AutoencoderKL.from_pretrained(
                sd_model,
                subfolder="vae",
                torch_dtype=self.weight_dtype,
            ).to(self.device)
            # Load the text encoder from SD 1.5
            print(" - Loading text encoder...")
            self.text_encoder = CLIPTextModel.from_pretrained(
                sd_model,
                subfolder="text_encoder",
                torch_dtype=self.weight_dtype,
            ).to(self.device)

            # Load the tokenizer
            print(" - Loading tokenizer...")
            self.tokenizer = CLIPTokenizer.from_pretrained(
                sd_model,
                subfolder="tokenizer",
            )
            # Download the custom MagicFace checkpoint
            print(" - Downloading the custom MagicFace checkpoint (79999_iter.pth)...")
            magicface_path = hf_hub_download(
                repo_id="gauravvjhaa/magicface-affecto-model",
                filename="79999_iter.pth",
                cache_dir="./models",
            )
            print(f" - MagicFace checkpoint downloaded: {magicface_path}")

            # Load the checkpoint into memory
            print(" - Loading MagicFace weights...")
            magicface_weights = torch.load(magicface_path, map_location=self.device)
            # Initialize both UNets from the official MagicFace architecture
            print(" - Initializing ID UNet...")
            self.unet_ID = UNetID2DConditionModel.from_pretrained(
                'mengtingwei/magicface',  # official architecture
                subfolder='ID_enc',
                torch_dtype=self.weight_dtype,
                use_safetensors=True,
                low_cpu_mem_usage=True,
            ).to(self.device)

            print(" - Initializing denoising UNet...")
            self.unet_deno = UNetDeno2DConditionModel.from_pretrained(
                'mengtingwei/magicface',  # official architecture
                subfolder='denoising_unet',
                torch_dtype=self.weight_dtype,
                use_safetensors=True,
                low_cpu_mem_usage=True,
            ).to(self.device)
            # Load the custom weights into the UNets. The exact layout of
            # 79999_iter.pth is not fixed: it may hold both UNets, a wrapped
            # state dict, or a raw state dict.
            print(" - Loading custom trained weights...")
            try:
                if isinstance(magicface_weights, dict):
                    # Inspect a few checkpoint keys to see how it is structured
                    print(f" - Checkpoint keys: {list(magicface_weights.keys())[:5]}...")

                    # Option 1: full checkpoint with one entry per UNet
                    if 'unet_ID' in magicface_weights or 'unet_deno' in magicface_weights:
                        if 'unet_ID' in magicface_weights:
                            self.unet_ID.load_state_dict(magicface_weights['unet_ID'])
                        if 'unet_deno' in magicface_weights:
                            self.unet_deno.load_state_dict(magicface_weights['unet_deno'])
                    # Option 2: a single wrapped state dict
                    elif 'state_dict' in magicface_weights:
                        self.unet_deno.load_state_dict(magicface_weights['state_dict'])
                    # Option 3: the raw state dict itself
                    else:
                        self.unet_deno.load_state_dict(magicface_weights)
                    print(" ✅ Custom weights loaded successfully!")
                else:
                    print(" ⚠️ Unexpected checkpoint format, using default weights")
            except Exception as e:
                print(f" ⚠️ Could not load custom weights: {e}")
                print(" ⚠️ Using default pretrained weights from mengtingwei/magicface")
            # Freeze all components and switch to eval mode
            self.vae.requires_grad_(False)
            self.text_encoder.requires_grad_(False)
            self.unet_ID.requires_grad_(False)
            self.unet_deno.requires_grad_(False)

            self.vae.eval()
            self.text_encoder.eval()
            self.unet_ID.eval()
            self.unet_deno.eval()

            # Assemble the MagicFace pipeline
            print(" - Creating MagicFace pipeline...")
            self.pipeline = MgPipeline.from_pretrained(
                sd_model,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                unet_ID=self.unet_ID,
                unet_deno=self.unet_deno,
                safety_checker=None,
                torch_dtype=self.weight_dtype,
            ).to(self.device)

            # Swap in the UniPC scheduler
            self.pipeline.scheduler = UniPCMultistepScheduler.from_config(
                self.pipeline.scheduler.config
            )
            self.pipeline.set_progress_bar_config(disable=False)

            print("✅ MagicFace loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading MagicFace: {e}")
            import traceback
            traceback.print_exc()
            raise
    def tokenize_caption(self, caption: str):
        """Tokenize a text prompt."""
        inputs = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return inputs.input_ids.to(self.device)
    def prepare_au_vector(self, au_params: dict):
        """Convert a dict of AU intensities to a (1, 12) tensor."""
        au_prompt = np.zeros((12,))
        for au_name, value in au_params.items():
            if au_name in self.ind_dict:
                au_prompt[self.ind_dict[au_name]] = float(value)
        print(f" AU vector: {au_prompt}")
        return torch.from_numpy(au_prompt).float().unsqueeze(0).to(self.device)
    def transform(self, source_image, bg_image, au_params, num_inference_steps=50, seed=424):
        """
        Transform a facial expression with MagicFace.

        Args:
            source_image: PIL Image (512x512, cropped face)
            bg_image: PIL Image (512x512, background)
            au_params: dict such as {"AU6": 2.0, "AU12": 2.0}
            num_inference_steps: number of diffusion steps
            seed: random seed

        Returns:
            PIL Image (transformed)
        """
| print(f"π Starting MagicFace transformation...") | |
| print(f" AU params: {au_params}") | |
| print(f" Inference steps: {num_inference_steps}") | |
| try: | |
| # Prepare inputs | |
| transform = transforms.ToTensor() | |
| source = transform(source_image).unsqueeze(0).to(self.device) | |
| bg = transform(bg_image).unsqueeze(0).to(self.device) | |
| # Get text embeddings | |
| prompt = "A close up of a person." | |
| prompt_ids = self.tokenize_caption(prompt) | |
| prompt_embeds = self.text_encoder(prompt_ids)[0] | |
| # Prepare AU vector | |
| au_vector = self.prepare_au_vector(au_params) | |
| # Set seed | |
| generator = torch.Generator(device=self.device).manual_seed(seed) | |
| # Run inference | |
| print(" π Running diffusion pipeline...") | |
| result = self.pipeline( | |
| prompt_embeds=prompt_embeds, | |
| source=source, | |
| bg=bg, | |
| au=au_vector, | |
| num_inference_steps=num_inference_steps, | |
| generator=generator, | |
| ) | |
| print("β Transformation complete!") | |
| return result.images[0] | |
| except Exception as e: | |
| print(f"β Transformation error: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| raise |
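

# Usage sketch (not part of the original class): a minimal example of how
# MagicFaceModel might be driven, assuming two pre-cropped 512x512 RGB images.
# The file names "face.png", "background.png", "transformed.png" and the AU
# intensities below are illustrative assumptions, not values from the source.
if __name__ == "__main__":
    model = MagicFaceModel(device="cuda")

    source = Image.open("face.png").convert("RGB").resize((512, 512))
    background = Image.open("background.png").convert("RGB").resize((512, 512))

    # AU6 (cheek raiser) + AU12 (lip corner puller) roughly corresponds to a smile
    result = model.transform(
        source_image=source,
        bg_image=background,
        au_params={"AU6": 2.0, "AU12": 2.0},
        num_inference_steps=50,
        seed=424,
    )
    result.save("transformed.png")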