gauravvjhaa committed on
Commit 3ece5ec · 1 Parent(s): b6a80c5

Add real MagicFace model structure (simplified for now)

Files changed (3)
  1. app.py +99 -181
  2. magicface_model.py +130 -0
  3. requirements.txt +4 -3
app.py CHANGED
@@ -1,68 +1,27 @@
 import gradio as gr
 import torch
-import numpy as np
 from PIL import Image
 import base64
 from io import BytesIO
 import json
-from huggingface_hub import hf_hub_download
 
 print("🚀 Starting Affecto Inference Service...")
 
-# ============================================
-# MODEL LOADING
-# ============================================
+# Import our MagicFace model
+from magicface_model import MagicFaceModel
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Initialize model
+device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"🖥️ Device: {device}")
 
-# Download your model
-print("📥 Downloading MagicFace model...")
-model_path = hf_hub_download(
-    repo_id="gauravvjhaa/magicface-affecto-model",
-    filename="79999_iter.pth",
-    cache_dir="./models"
-)
-print(f"✅ Model downloaded to: {model_path}")
-
-# Load checkpoint
-checkpoint = torch.load(model_path, map_location=device)
-print(f"📦 Checkpoint loaded successfully")
+print("📥 Loading MagicFace model...")
+model = MagicFaceModel(device=device)
+print("✅ Model ready!")
 
 # ============================================
-# IMAGE PROCESSING UTILITIES
+# UTILITY FUNCTIONS
 # ============================================
 
-import torchvision.transforms as transforms
-
-def preprocess_image(image):
-    """Convert PIL image to tensor"""
-    if not isinstance(image, Image.Image):
-        image = Image.fromarray(image)
-
-    if image.mode != 'RGB':
-        image = image.convert('RGB')
-
-    transform = transforms.Compose([
-        transforms.Resize((256, 256)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-    ])
-
-    tensor = transform(image).unsqueeze(0)
-    return tensor.to(device)
-
-def postprocess_tensor(tensor):
-    """Convert tensor to PIL image"""
-    tensor = tensor.squeeze(0).cpu()
-    tensor = tensor * 0.5 + 0.5
-    tensor = torch.clamp(tensor, 0, 1)
-
-    numpy_image = tensor.numpy().transpose(1, 2, 0)
-    numpy_image = (numpy_image * 255).astype(np.uint8)
-
-    return Image.fromarray(numpy_image)
-
 def pil_to_base64(image):
     """Convert PIL to base64"""
     buffered = BytesIO()
@@ -75,115 +34,76 @@ def base64_to_pil(base64_str):
     return Image.open(BytesIO(image_bytes))
 
 # ============================================
-# TRANSFORMATION
+# INFERENCE FUNCTIONS
 # ============================================
 
-def apply_emotion_transform(input_tensor, au_params):
-    """Apply emotion transformation (placeholder)"""
-    print(f"🎭 Applying transformation with AU params: {au_params}")
-
-    output = input_tensor.clone()
-
-    if "AU12" in au_params:
-        intensity = au_params["AU12"]
-        output = output * (1.0 + intensity * 0.2)
-
-    if "AU4" in au_params:
-        intensity = au_params["AU4"]
-        output = output * (1.0 - intensity * 0.15)
-
-    output = torch.clamp(output, -1, 1)
-    return output
-
-# ============================================
-# API FUNCTIONS
-# ============================================
+def transform_gradio(image, au_params_str):
+    """Gradio interface function"""
+    try:
+        # Parse AU params
+        au_params = json.loads(au_params_str)
+
+        # Ensure image is 512x512
+        if image.size != (512, 512):
+            image = image.resize((512, 512), Image.LANCZOS)
+
+        # Transform
+        result_image = model.transform(image, au_params)
+
+        return result_image
+    except Exception as e:
+        print(f"❌ Error: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return image
 
-def transform_api(data):
+def transform_api(image_base64, au_params_str):
     """API function for external calls"""
     try:
-        image_base64 = data["image"]
-        au_params = data["au_params"]
-
-        print(f"📥 Received API request with AU params: {au_params}")
+        print(f"📥 Received API request")
 
+        # Decode image
         image = base64_to_pil(image_base64)
         print(f"📸 Image size: {image.size}")
 
-        input_tensor = preprocess_image(image)
-        output_tensor = apply_emotion_transform(input_tensor, au_params)
-        result_image = postprocess_tensor(output_tensor)
+        # Parse AU params
+        au_params = json.loads(au_params_str)
+
+        # Ensure 512x512
+        if image.size != (512, 512):
+            image = image.resize((512, 512), Image.LANCZOS)
+
+        # Transform
+        result_image = model.transform(image, au_params)
+
+        # Encode result
         result_base64 = pil_to_base64(result_image)
 
        print("✅ Transformation complete")
 
-        return {
-            "success": True,
-            "transformed_image": result_base64,
-            "au_params": au_params,
-            "message": "Transformation successful"
-        }
+        return result_base64
     except Exception as e:
         print(f"❌ API Error: {str(e)}")
         import traceback
         traceback.print_exc()
-        return {
-            "success": False,
-            "error": str(e),
-            "message": "Transformation failed"
-        }
-
-def health_check():
-    """Health check function"""
-    return {
-        "status": "healthy",
-        "model": "magicface",
-        "device": str(device),
-        "version": "1.0.0"
-    }
-
-def root_info():
-    """Root info function"""
-    return {
-        "message": "Affecto Inference API",
-        "status": "running",
-        "version": "1.0.0",
-        "endpoints": {
-            "health": "/health",
-            "transform": "/transform"
-        }
-    }
+        raise
 
 # ============================================
 # GRADIO INTERFACE
 # ============================================
 
-def transform_gradio(image, au_params_str):
-    """Gradio interface function"""
-    try:
-        au_params = json.loads(au_params_str)
-        input_tensor = preprocess_image(image)
-        output_tensor = apply_emotion_transform(input_tensor, au_params)
-        result_image = postprocess_tensor(output_tensor)
-        return result_image
-    except Exception as e:
-        print(f"❌ Error: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        return image
-
-# Build Gradio interface
-with gr.Blocks(theme=gr.themes.Soft(), title="Affecto Inference API") as demo:
-    gr.Markdown("# 🎭 Affecto - Emotion Transformation API")
-    gr.Markdown("Transform facial emotions using MagicFace Action Units")
+with gr.Blocks(theme=gr.themes.Soft(), title="Affecto MagicFace API") as demo:
+    gr.Markdown("# 🎭 Affecto - MagicFace Emotion Transformation")
+    gr.Markdown("Transform facial emotions using Action Units (AU)")
+    gr.Markdown("⚠️ **Note:** Currently using simplified model. Full MagicFace pipeline coming soon!")
 
     with gr.Tab("🖼️ Web Interface"):
         with gr.Row():
            with gr.Column():
-                input_image = gr.Image(type="pil", label="Upload Image")
+                input_image = gr.Image(type="pil", label="Upload Face Image (512x512 recommended)")
                 au_params_input = gr.Textbox(
                     label="AU Parameters (JSON)",
-                    value='{"AU6": 1.0, "AU12": 1.0}',
+                    value='{"AU6": 2.0, "AU12": 2.0}',
                     lines=3
                 )
                 transform_btn = gr.Button("✨ Transform", variant="primary", size="lg")
@@ -191,15 +111,16 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Affecto Inference API") as demo:
             with gr.Column():
                 output_image = gr.Image(type="pil", label="Transformed Result")
 
-        gr.Markdown("### 🎨 Emotion Presets:")
+        gr.Markdown("### 🎨 Emotion Presets (click to use):")
         gr.Examples(
             examples=[
-                ['{"AU6": 1.0, "AU12": 1.0}'],
-                ['{"AU1": 1.0, "AU4": 1.0, "AU15": 1.0}'],
-                ['{"AU4": 1.0, "AU5": 1.0, "AU7": 1.0, "AU23": 1.0}'],
-                ['{"AU1": 1.0, "AU2": 1.0, "AU5": 1.0, "AU26": 1.0}'],
+                ['{"AU6": 2.0, "AU12": 2.0}'],  # Happy
+                ['{"AU1": 2.0, "AU4": 2.0, "AU15": 2.0}'],  # Sad
+                ['{"AU4": 3.0, "AU5": 2.0, "AU7": 2.0}'],  # Angry
+                ['{"AU1": 3.0, "AU2": 2.0, "AU5": 3.0, "AU26": 2.0}'],  # Surprised
             ],
             inputs=[au_params_input],
+            label="Emotion Presets"
         )
 
         transform_btn.click(
@@ -210,60 +131,57 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Affecto Inference API") as demo:
 
     with gr.Tab("📡 API Documentation"):
         gr.Markdown("""
-        ## API Endpoints
-
-        ### Transform Image
-        **POST** `/api/transform`
-
-        ```json
-        {
-            "image": "base64_encoded_image",
-            "au_params": {"AU6": 1.0, "AU12": 1.0}
-        }
+        ## API Usage
+
+        ### Gradio API Endpoint
+
+        ```python
+        import requests
+        import base64
+        import json
+
+        # Prepare image
+        with open("face.jpg", "rb") as f:
+            image_base64 = base64.b64encode(f.read()).decode()
+
+        # Call API
+        response = requests.post(
+            "https://gauravvjhaa-affecto-inference.hf.space/api/predict",
+            json={
+                "data": [
+                    image_base64,
+                    '{"AU6": 2.0, "AU12": 2.0}'
+                ]
+            }
+        )
+
+        result = response.json()
+        result_image = result["data"][0]  # base64 string
        ```
 
-        ### Health Check
-        **GET** `/api/health`
-
-        Returns service status and model information.
+        ### Available Action Units:
+        - **AU1** (0): Inner Brow Raiser - Values: 0-4
+        - **AU2** (1): Outer Brow Raiser - Values: 0-4
+        - **AU4** (2): Brow Lowerer - Values: 0-4
+        - **AU5** (3): Upper Lid Raiser - Values: 0-4
+        - **AU6** (4): Cheek Raiser - Values: 0-4
+        - **AU9** (5): Nose Wrinkler - Values: 0-4
+        - **AU12** (6): Lip Corner Puller (Smile) - Values: 0-4
+        - **AU15** (7): Lip Corner Depressor - Values: 0-4
+        - **AU17** (8): Chin Raiser - Values: 0-4
+        - **AU20** (9): Lip Stretcher - Values: 0-4
+        - **AU25** (10): Lips Part - Values: 0-4
+        - **AU26** (11): Jaw Drop - Values: 0-4
+
+        ### Example Combinations:
+        - **Happy**: `{"AU6": 2, "AU12": 2}`
+        - **Sad**: `{"AU1": 2, "AU4": 2, "AU15": 2}`
+        - **Angry**: `{"AU4": 3, "AU5": 2, "AU7": 2}`
+        - **Surprised**: `{"AU1": 3, "AU2": 2, "AU5": 3, "AU26": 2}`
         """)
-
-    # API endpoints as Gradio functions
-    with gr.Tab("🔌 API"):
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("### POST /api/transform")
-                api_input = gr.Textbox(
-                    label="Request JSON",
-                    value='{"image": "BASE64_STRING", "au_params": {"AU6": 1.0}}',
-                    lines=5
-                )
-                api_btn = gr.Button("Test API")
-                api_output = gr.JSON(label="Response")
-
-                api_btn.click(
-                    fn=lambda x: transform_api(json.loads(x)),
-                    inputs=[api_input],
-                    outputs=[api_output]
-                )
-
-            with gr.Column():
-                gr.Markdown("### GET /api/health")
-                health_btn = gr.Button("Check Health")
-                health_output = gr.JSON(label="Health Status")
-
-                health_btn.click(
-                    fn=health_check,
-                    inputs=[],
-                    outputs=[health_output]
-                )
-
-# Add API routes using Gradio's API
-demo.api_names = ["transform", "health", "root"]
 
-print("✅ Affecto Inference API Ready!")
+print("✅ Affecto MagicFace API Ready!")
 print(f"🌐 Gradio UI: https://gauravvjhaa-affecto-inference.hf.space/")
 
-# Launch
 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=7860)
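Note that `transform_api` now returns a bare base64 string rather than the old JSON envelope with `success`/`transformed_image` fields, so callers must decode `result["data"][0]` themselves. A minimal client-side sketch of that last step (illustrative only, not part of this commit; the file name is an assumption):

```python
# decode_result.py - illustrative helper, not part of this commit
import base64
from io import BytesIO
from PIL import Image

def save_result(result_base64: str, path: str = "result.png") -> None:
    """Decode the base64 string returned by the API and save it as an image."""
    image = Image.open(BytesIO(base64.b64decode(result_base64)))  # mirrors base64_to_pil in app.py
    image.save(path)
```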
magicface_model.py ADDED
@@ -0,0 +1,130 @@
+import torch
+import numpy as np
+from PIL import Image
+import torchvision.transforms as transforms
+from diffusers import AutoencoderKL, UniPCMultistepScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+
+# We'll need to implement these custom UNet classes
+# For now, we'll use a simplified version
+
+class MagicFaceModel:
+    def __init__(self, device='cuda'):
+        self.device = device if torch.cuda.is_available() else 'cpu'
+        print(f"🖥️ Initializing MagicFace on: {self.device}")
+
+        # AU mapping (same as original)
+        self.ind_dict = {
+            'AU1': 0, 'AU2': 1, 'AU4': 2, 'AU5': 3, 'AU6': 4, 'AU9': 5,
+            'AU12': 6, 'AU15': 7, 'AU17': 8, 'AU20': 9, 'AU25': 10, 'AU26': 11
+        }
+
+        self.load_models()
+
+    def load_models(self):
+        """Load all required models"""
+        print("📥 Loading Stable Diffusion components...")
+
+        # Load VAE
+        self.vae = AutoencoderKL.from_pretrained(
+            'runwayml/stable-diffusion-v1-5',
+            subfolder="vae",
+            torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32
+        ).to(self.device)
+
+        # Load Text Encoder
+        self.text_encoder = CLIPTextModel.from_pretrained(
+            'runwayml/stable-diffusion-v1-5',
+            subfolder="text_encoder",
+        ).to(self.device)
+
+        # Load Tokenizer
+        self.tokenizer = CLIPTokenizer.from_pretrained(
+            'runwayml/stable-diffusion-v1-5',
+            subfolder="tokenizer",
+        )
+
+        # TODO: Load custom UNets from mengtingwei/magicface
+        # For now, we'll use a simplified approach
+        print("⚠️ Using simplified model (custom UNets not yet loaded)")
+
+        self.vae.requires_grad_(False)
+        self.text_encoder.requires_grad_(False)
+
+        print("✅ Models loaded successfully")
+
+    def preprocess_image(self, image: Image.Image):
+        """Preprocess image for inference"""
+        transform = transforms.Compose([
+            transforms.Resize((512, 512)),
+            transforms.ToTensor(),
+        ])
+        return transform(image).unsqueeze(0).to(self.device)
+
+    def prepare_au_vector(self, au_params: dict):
+        """Convert AU parameters dict to a 12-dim tensor"""
+        au_prompt = np.zeros((12,))
+
+        for au_name, value in au_params.items():
+            if au_name in self.ind_dict:
+                au_prompt[self.ind_dict[au_name]] = value
+
+        return torch.from_numpy(au_prompt).float().unsqueeze(0).to(self.device)
+
+    def tokenize_caption(self, caption: str):
+        """Tokenize text prompt"""
+        inputs = self.tokenizer(
+            caption,
+            max_length=self.tokenizer.model_max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt"
+        )
+        return inputs.input_ids.to(self.device)
+
+    @torch.no_grad()
+    def transform(self, image: Image.Image, au_params: dict):
+        """
+        Transform facial expression based on AU parameters
+
+        Args:
+            image: PIL Image (512x512)
+            au_params: dict like {"AU6": 1.0, "AU12": 1.0}
+
+        Returns:
+            PIL Image (transformed)
+        """
+        print(f"🎭 Transforming with AU params: {au_params}")
+
+        # Preprocess
+        source_tensor = self.preprocess_image(image)
+        au_vector = self.prepare_au_vector(au_params)
+
+        # Get text embeddings
+        prompt = "A close up of a person."
+        prompt_ids = self.tokenize_caption(prompt)
+        prompt_embeds = self.text_encoder(prompt_ids)[0]
+
+        # TODO: Implement full diffusion pipeline with custom UNets
+        # For now, return a simple transformation
+        print("⚠️ Using simplified transformation (full pipeline not yet implemented)")
+
+        # Placeholder: apply a simple brightness adjustment based on AUs
+        output_tensor = source_tensor.clone()
+
+        # AU12 (smile) - brighten
+        if "AU12" in au_params:
+            output_tensor = output_tensor * (1.0 + au_params["AU12"] * 0.3)
+
+        # AU4 (frown) - darken
+        if "AU4" in au_params:
+            output_tensor = output_tensor * (1.0 - au_params["AU4"] * 0.2)
+
+        output_tensor = torch.clamp(output_tensor, 0, 1)
+
+        # Convert back to PIL
+        output_np = output_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
+        output_np = (output_np * 255).astype(np.uint8)
+        result_image = Image.fromarray(output_np)
+
+        return result_image
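For a quick local check of the new module, a minimal smoke test might look like this (a sketch, not part of this commit: it assumes the pinned dependencies below are installed, that a `face.jpg` exists locally, and that the first run can download the Stable Diffusion components from the Hub):

```python
# smoke_test.py - illustrative only, not part of this commit
from PIL import Image
from magicface_model import MagicFaceModel

model = MagicFaceModel(device="cuda")  # falls back to CPU automatically if CUDA is unavailable
image = Image.open("face.jpg").convert("RGB")  # resized to 512x512 inside transform()

# "Happy" preset: cheek raiser (AU6) + lip corner puller (AU12)
result = model.transform(image, {"AU6": 2.0, "AU12": 2.0})
result.save("face_happy.jpg")
```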
requirements.txt CHANGED
@@ -1,9 +1,10 @@
 torch==2.0.1
 torchvision==0.15.2
 gradio==4.16.0
-fastapi==0.109.0
-uvicorn[standard]==0.27.0
+diffusers==0.21.4
+transformers==4.35.2
+accelerate==0.24.1
+safetensors==0.4.1
 Pillow==10.2.0
 numpy==1.26.3
 huggingface-hub==0.20.3
-python-multipart==0.0.6
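The new pins add the diffusion stack that `magicface_model.py` imports, while dropping the FastAPI/uvicorn dependencies that the removed custom API endpoints needed. A quick sanity check that the pinned versions install and import together (illustrative only; run after `pip install -r requirements.txt`):

```python
# env_check.py - illustrative only, not part of this commit
import torch, torchvision, gradio, diffusers, transformers, accelerate, safetensors

# Print the versions actually resolved in the environment
print(torch.__version__, diffusers.__version__, transformers.__version__)
```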