broadfield-dev committed
Commit 5a1196d · verified · 1 Parent(s): 21d341d

Update app.py

Files changed (1)
  1. app.py +121 -93
app.py CHANGED
@@ -9,7 +9,10 @@ import subprocess
 from datetime import datetime
 from pathlib import Path
 from huggingface_hub import HfApi
-from transformers import AutoConfig, AutoTokenizer # Keep AutoTokenizer for ONNX pipeline
+from transformers import AutoConfig, AutoModel, AutoTokenizer
+from optimum.onnxruntime import ORTQuantizer
+from optimum.onnxruntime.configuration import AutoQuantizationConfig
+import torch.nn.utils.prune as prune

 # --- SETUP ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -23,34 +26,39 @@ OUTPUT_DIR = "optimized_models"
 os.makedirs(OUTPUT_DIR, exist_ok=True)

 # --- LLAMA.CPP SETUP ---
+## FIX: Define paths at the global scope so all functions can access them.
 LLAMA_CPP_DIR = Path("llama.cpp")
+LLAMA_CPP_CONVERT_SCRIPT = LLAMA_CPP_DIR / "convert.py"
+LLAMA_CPP_QUANTIZE_SCRIPT = LLAMA_CPP_DIR / "quantize" # This is a compiled binary

 def setup_llama_cpp():
-    """Clones llama.cpp if not already present and builds it."""
+    """Clones and builds llama.cpp if not already present."""
     if not LLAMA_CPP_DIR.exists():
         logging.info("Cloning llama.cpp repository...")
         try:
             subprocess.run(["git", "clone", "https://github.com/ggerganov/llama.cpp.git"], check=True, capture_output=True, text=True)
             logging.info("llama.cpp cloned successfully.")
-            logging.info("Building llama.cpp...")
-            # Build the required tools
-            subprocess.run(["make", "-C", "llama.cpp", "quantize", "convert.py"], check=True, capture_output=True, text=True)
-            logging.info("llama.cpp built successfully.")
         except subprocess.CalledProcessError as e:
-            error_msg = f"Failed to clone or build llama.cpp. This is required for GGUF conversion. Error: {e.stderr}"
+            error_msg = f"Failed to clone llama.cpp. Error: {e.stderr}"
+            logging.error(error_msg, exc_info=True)
+            raise RuntimeError(error_msg)
+
+    if not LLAMA_CPP_QUANTIZE_SCRIPT.exists():
+        logging.info("llama.cpp 'quantize' binary not found. Attempting to build...")
+        try:
+            # Use 'make' to build the necessary tools
+            subprocess.run(["make", "-C", str(LLAMA_CPP_DIR), "quantize"], check=True, capture_output=True, text=True)
+            logging.info("'quantize' binary built successfully.")
+        except subprocess.CalledProcessError as e:
+            error_msg = f"Failed to build llama.cpp 'quantize' binary. Error: {e.stderr}"
             logging.error(error_msg, exc_info=True)
             raise RuntimeError(error_msg)

 # Run setup on script start
 try:
     setup_llama_cpp()
-    LLAMA_CPP_CONVERT_SCRIPT = LLAMA_CPP_DIR / "convert.py"
-    LLAMA_CPP_QUANTIZE_SCRIPT = LLAMA_CPP_DIR / "quantize" # This is a binary, not a python script
-    if not LLAMA_CPP_CONVERT_SCRIPT.exists() or not LLAMA_CPP_QUANTIZE_SCRIPT.exists():
-        raise RuntimeError("llama.cpp scripts/binaries not found after setup.")
 except Exception as e:
     logging.error(f"FATAL ERROR during llama.cpp setup: {e}", exc_info=True)
-    # The app will likely fail to start, which is appropriate.


 def stage_1_analyze_model(model_id: str):
@@ -71,68 +79,66 @@ def stage_1_analyze_model(model_id: str):
         logging.error(error_msg)
         return log_stream + error_msg, "Could not analyze model.", gr.Accordion(open=False)

-def stage_3_4_onnx_quantize(model_id: str, onnx_quant_type: str, calibration_data_path: str):
-    # MODIFIED: Takes model_id directly
-    log_stream = "[STAGE 2 & 3] Converting to ONNX and Quantizing...\n"
+## RE-INTEGRATED: This function is brought back from your original code.
+def stage_2_prune_model(model, prune_percentage: float):
+    if prune_percentage == 0:
+        return model, "Skipped pruning as percentage was 0."
+    log_stream = "[STAGE 2] Pruning model...\n"
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            prune.l1_unstructured(module, name='weight', amount=prune_percentage / 100.0)
+            prune.remove(module, 'weight')
+    log_stream += f"Pruning complete with {prune_percentage}% target.\n"
+    return model, log_stream
+
+def stage_3_4_onnx_quantize(model_path_or_id: str, onnx_quant_type: str, calibration_data_path: str):
+    log_stream = "[STAGE 3 & 4] Converting to ONNX and Quantizing...\n"
     run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
-    model_name = model_id.split('/')[-1]
-    onnx_base_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-onnx-unquantized")
-
+    model_name = model_path_or_id.split('/')[-1]
+    onnx_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-onnx")
+
     try:
-        log_stream += f"Executing `optimum-cli export onnx` for model '{model_id}'...\n"
-        export_command = ["optimum-cli", "export", "onnx", "--model", model_id, "--trust-remote-code", onnx_base_path]
+        log_stream += f"Executing `optimum-cli export onnx` for '{model_path_or_id}'...\n"
+        export_command = ["optimum-cli", "export", "onnx", "--model", model_path_or_id, "--trust-remote-code", onnx_path]
         process = subprocess.run(export_command, check=True, capture_output=True, text=True)
         log_stream += process.stdout
         if process.stderr: log_stream += f"[STDERR]\n{process.stderr}\n"
-        log_stream += f"Successfully exported to ONNX at: {onnx_base_path}\n"
+        log_stream += f"Successfully exported to ONNX at: {onnx_path}\n"
     except subprocess.CalledProcessError as e:
         raise RuntimeError(f"Failed during `optimum-cli export onnx`. Error:\n{e.stderr}")

     try:
-        log_stream += f"Executing `optimum-cli onnx quantize` for model at '{onnx_base_path}'...\n"
-        quantized_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-onnx-quantized")
-        quantize_command = ["optimum-cli", "onnx", "quantize", "--onnx_model", onnx_base_path, "--avx512", "-o", quantized_path]
+        # For simplicity and stability on HF Spaces, we will only use Dynamic Quantization via CLI.
+        quantizer = ORTQuantizer.from_pretrained(onnx_path)
+        log_stream += "Performing DYNAMIC quantization...\n"
+        dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
+        quantized_path = os.path.join(onnx_path, "quantized-dynamic")
+        quantizer.quantize(save_dir=quantized_path, quantization_config=dqconfig)
+        log_stream += f"Successfully quantized model to: {quantized_path}\n"

-        if onnx_quant_type == "Static" and calibration_data_path:
-            log_stream += "Using STATIC quantization with provided calibration data.\n"
-            # NOTE: optimum-cli quantization is more complex for static. This example simplifies to dynamic.
-            # For a real implementation, you would need to construct a more complex calibration configuration.
-            # For stability in a public space, we'll stick to the more reliable dynamic quantization.
-            log_stream += "[WARNING] Static quantization via CLI is complex and not fully implemented in this UI. Falling back to dynamic.\n"
-            quantize_command.append("--dynamic")
-        else:
-            log_stream += "Using DYNAMIC quantization...\n"
-            quantize_command.append("--dynamic")
-
-        process = subprocess.run(quantize_command, check=True, capture_output=True, text=True)
-        log_stream += process.stdout
-        if process.stderr: log_stream += f"[STDERR]\n{process.stderr}\n"
-
-        # Copy tokenizer config
-        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-        tokenizer.save_pretrained(quantized_path)
-        log_stream += f"Successfully quantized model and saved tokenizer to: {quantized_path}\n"
+        # If the original input was a model_id, we need to save a new tokenizer.
+        # If it was a local path (from pruning), the tokenizer is already there.
+        if not os.path.exists(os.path.join(quantized_path, 'tokenizer_config.json')):
+            AutoTokenizer.from_pretrained(model_path_or_id, trust_remote_code=True).save_pretrained(quantized_path)
+            log_stream += "Saved new tokenizer files.\n"
+
         return quantized_path, log_stream
-    except subprocess.CalledProcessError as e:
-        raise RuntimeError(f"Failed during `optimum-cli onnx quantize`. Error:\n{e.stderr}")
     except Exception as e:
-        raise RuntimeError(f"An unexpected error occurred during ONNX processing. Error: {e}")
+        raise RuntimeError(f"Failed during ONNX quantization step. Error: {e}")

-def stage_3_4_gguf_quantize(model_id: str, quantization_strategy: str):
-    # MODIFIED: Takes model_id directly
-    log_stream = "[STAGE 2 & 3] Converting to GGUF using llama.cpp...\n"
+def stage_3_4_gguf_quantize(model_path_or_id: str, original_model_id: str, quantization_strategy: str):
+    log_stream = "[STAGE 3 & 4] Converting to GGUF using llama.cpp...\n"
     run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
-    model_name_sanitized = model_id.replace('/', '_')
-    gguf_output_dir = os.path.join(OUTPUT_DIR, f"{model_name_sanitized}-{run_id}-gguf")
-    os.makedirs(gguf_output_dir, exist_ok=True)
+    model_name = original_model_id.replace('/', '_') # Use original ID for consistent naming
+    gguf_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-gguf")
+    os.makedirs(gguf_path, exist_ok=True)

-    f16_gguf_path = os.path.join(gguf_output_dir, "model-f16.gguf")
-    final_quantized_gguf_path = os.path.join(gguf_output_dir, "model.gguf")
+    f16_gguf_path = os.path.join(gguf_path, "model-f16.gguf")
+    quantized_gguf_path = os.path.join(gguf_path, "model.gguf")

     try:
-        log_stream += "Executing llama.cpp convert.py script...\n"
-        # The convert script can take the model ID directly and will use the cache
-        convert_command = ["python3", str(LLAMA_CPP_CONVERT_SCRIPT), model_id, "--outfile", f16_gguf_path, "--outtype", "f16"]
+        log_stream += f"Executing llama.cpp convert.py script on '{model_path_or_id}'...\n"
+        convert_command = ["python3", str(LLAMA_CPP_CONVERT_SCRIPT), model_path_or_id, "--outfile", f16_gguf_path, "--outtype", "f16"]
         process = subprocess.run(convert_command, check=True, capture_output=True, text=True)
         log_stream += process.stdout
         if process.stderr: log_stream += f"[STDERR]\n{process.stderr}\n"
@@ -142,23 +148,23 @@ def stage_3_4_gguf_quantize(model_id: str, quantization_strategy: str):

         if target_quant_name == "F16":
             log_stream += "Target is F16, renaming file...\n"
-            os.rename(f16_gguf_path, final_quantized_gguf_path)
+            os.rename(f16_gguf_path, quantized_gguf_path)
         else:
             log_stream += f"Quantizing FP16 GGUF to {target_quant_name}...\n"
-            quantize_command = [str(LLAMA_CPP_QUANTIZE_SCRIPT), f16_gguf_path, final_quantized_gguf_path, target_quant_name]
+            quantize_command = [str(LLAMA_CPP_QUANTIZE_SCRIPT), f16_gguf_path, quantized_gguf_path, target_quant_name]
             process = subprocess.run(quantize_command, check=True, capture_output=True, text=True)
             log_stream += process.stdout
             if process.stderr: log_stream += f"[STDERR]\n{process.stderr}\n"
-            os.remove(f16_gguf_path) # Clean up intermediate file
-        return gguf_output_dir, log_stream
+            os.remove(f16_gguf_path)
+        return gguf_path, log_stream
     except subprocess.CalledProcessError as e:
         raise RuntimeError(f"Failed during llama.cpp execution. Error:\n{e.stderr}")
     except Exception as e:
         raise RuntimeError(f"An unexpected error occurred during GGUF conversion. Error: {e}")

 def stage_5_package_and_upload(model_id: str, optimized_model_path: str, pipeline_log: str, options: dict):
-    # This function remains mostly correct, just updated placeholder for pruning
-    log_stream = "[STAGE 4] Packaging and Uploading...\n"
+    # This function is correct from your original version
+    log_stream = "[STAGE 5] Packaging and Uploading...\n"
     if not HF_TOKEN:
         return "Skipping upload: HF_TOKEN not found.", log_stream + "Skipping upload: HF_TOKEN not found."
     try:
@@ -166,44 +172,68 @@ def stage_5_package_and_upload(model_id: str, optimized_model_path: str, pipelin
         repo_url = api.create_repo(repo_id=repo_name, exist_ok=True, token=HF_TOKEN)
         template_file = "model_card_template_gguf.md" if options['pipeline_type'] == "GGUF" else "model_card_template.md"
         with open(template_file, "r", encoding="utf-8") as f: template_content = f.read()
-        # Updated pruning status to be hardcoded as disabled
-        model_card_content = template_content.format(repo_name=repo_name, model_id=model_id, optimization_date=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), pruning_status="Disabled", pruning_percent=0, quant_type=options.get('quant_type', 'N/A'), repo_id=repo_url.repo_id, pipeline_log=pipeline_log)
+        model_card_content = template_content.format(repo_name=repo_name, model_id=model_id, optimization_date=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), pruning_status="Enabled" if options.get('prune', False) else "Disabled", pruning_percent=options.get('prune_percent', 0), quant_type=options.get('quant_type', 'N/A'), repo_id=repo_url.repo_id, pipeline_log=pipeline_log)
         with open(os.path.join(optimized_model_path, "README.md"), "w", encoding="utf-8") as f: f.write(model_card_content)
         api.upload_folder(folder_path=optimized_model_path, repo_id=repo_url.repo_id, repo_type="model", token=HF_TOKEN)
-        log_stream += f"Upload complete to {repo_url.repo_id}.\n"
+        log_stream += "Upload complete.\n"
         return f"Success! Your optimized model is available at: huggingface.co/{repo_url.repo_id}", log_stream
     except Exception as e:
         raise RuntimeError(f"Failed to upload to the Hub. Error: {e}")

-def run_amop_pipeline(model_id: str, pipeline_type: str, onnx_quant_type: str, calibration_file, gguf_quant_type: str):
-    # REFACTORED: Removed pruning and in-memory model loading
+## RE-INTEGRATED: The main pipeline function now handles both pruning and no-pruning paths.
+def run_amop_pipeline(model_id: str, pipeline_type: str, do_prune: bool, prune_percent: float, onnx_quant_type: str, calibration_file, gguf_quant_type: str):
     if not model_id:
         yield {log_output: "Please enter a Model ID.", final_output: "Idle"}
         return

-    initial_log = f"[START] AMOP {pipeline_type} Pipeline Initiated for model '{model_id}'.\n"
+    initial_log = f"[START] AMOP {pipeline_type} Pipeline Initiated for '{model_id}'.\n"
     yield {run_button: gr.Button(interactive=False, value="🚀 Running..."), analyze_button: gr.Button(interactive=False), final_output: f"RUNNING ({pipeline_type})", log_output: initial_log}

     full_log = initial_log
+    temp_model_dir = None
+    model_path_or_id = model_id # Default to memory-efficient path
+
     try:
         whoami = api.whoami(token=HF_TOKEN)
         if not whoami: raise RuntimeError("Could not authenticate with Hugging Face Hub. Check your HF_TOKEN.")
         repo_id_for_link = f"{whoami['name']}/{model_id.split('/')[-1]}-amop-cpu-{pipeline_type.lower()}"

-        # The pipeline now has fewer, more robust steps
+        # --- STAGE 2: OPTIONAL PRUNING (Memory-intensive) ---
+        if do_prune and prune_percent > 0:
+            full_log += f"\n[WARNING] Pruning is memory-intensive and may fail for large models.\n"
+            full_log += "Loading base model for pruning...\n"; yield {final_output: "Loading model (1/5)", log_output: full_log}
+            model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
+            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+            full_log += f"Successfully loaded '{model_id}'.\n"
+
+            yield {final_output: "Pruning model (2/5)", log_output: full_log}
+            model, log = stage_2_prune_model(model, prune_percent)
+            full_log += log
+
+            # Save pruned model to a temporary directory for the next stage
+            temp_model_dir = tempfile.mkdtemp()
+            model.save_pretrained(temp_model_dir)
+            tokenizer.save_pretrained(temp_model_dir)
+            model_path_or_id = temp_model_dir # Next stages will use this local path
+            full_log += f"Saved intermediate pruned model to {temp_model_dir}\n"
+        else:
+            full_log += "Pruning skipped.\n"
+
+        # --- STAGE 3 & 4: CONVERSION & QUANTIZATION ---
         if pipeline_type == "ONNX":
-            full_log += "Starting ONNX Conversion & Quantization...\n"; yield {final_output: "Converting to ONNX (1/3)", log_output: full_log}
-            optimized_path, log = stage_3_4_onnx_quantize(model_id, onnx_quant_type, calibration_file.name if onnx_quant_type == "Static" and calibration_file else None)
-            options = {'pipeline_type': 'ONNX', 'quant_type': onnx_quant_type}
+            full_log += "Converting to ONNX...\n"; yield {final_output: "Converting to ONNX (3/5)", log_output: full_log}
+            optimized_path, log = stage_3_4_onnx_quantize(model_path_or_id, onnx_quant_type, calibration_file.name if onnx_quant_type == "Static" and calibration_file else None)
+            options = {'pipeline_type': 'ONNX', 'prune': do_prune, 'prune_percent': prune_percent, 'quant_type': onnx_quant_type}
         elif pipeline_type == "GGUF":
-            full_log += "Starting GGUF Conversion & Quantization...\n"; yield {final_output: "Converting to GGUF (1/3)", log_output: full_log}
-            optimized_path, log = stage_3_4_gguf_quantize(model_id, gguf_quant_type)
-            options = {'pipeline_type': 'GGUF', 'quant_type': gguf_quant_type}
+            full_log += "Converting to GGUF...\n"; yield {final_output: "Converting to GGUF (3/5)", log_output: full_log}
+            optimized_path, log = stage_3_4_gguf_quantize(model_path_or_id, model_id, gguf_quant_type)
+            options = {'pipeline_type': 'GGUF', 'prune': do_prune, 'prune_percent': prune_percent, 'quant_type': gguf_quant_type}
         else:
             raise ValueError("Invalid pipeline type selected.")
         full_log += log

-        full_log += "Packaging & Uploading...\n"; yield {final_output: "Packaging & Uploading (2/3)", log_output: full_log}
+        # --- STAGE 5: UPLOAD ---
+        full_log += "Packaging & Uploading...\n"; yield {final_output: "Packaging & Uploading (4/5)", log_output: full_log}
         final_message, log = stage_5_package_and_upload(model_id, optimized_path, full_log, options)
         full_log += log

@@ -213,11 +243,9 @@ def run_amop_pipeline(model_id: str, pipeline_type: str, onnx_quant_type: str, c
         full_log += f"\n[ERROR] Pipeline failed: {e}"
         yield {final_output: gr.update(value="ERROR", label="Status"), log_output: full_log, success_box: gr.Markdown(f"❌ **An error occurred.** Check logs for details.", visible=True), run_button: gr.Button(interactive=True, value="Run Optimization Pipeline", variant="primary"), analyze_button: gr.Button(interactive=True, value="Analyze Model")}
     finally:
-        # Clean up entire output directory to save space
-        if os.path.exists(OUTPUT_DIR):
-            shutil.rmtree(OUTPUT_DIR)
-            os.makedirs(OUTPUT_DIR, exist_ok=True)
-
+        # Clean up the temporary directory if it was created
+        if temp_model_dir and os.path.exists(temp_model_dir):
+            shutil.rmtree(temp_model_dir)

 # --- GRADIO UI ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -231,15 +259,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Accordion("⚙️ 2. Configure Optimization", open=False) as optimization_accordion:
         analysis_report_output = gr.Markdown()
         pipeline_type_radio = gr.Radio(["ONNX", "GGUF"], label="Select Optimization Pipeline")
-        # Pruning is removed for stability on HF Spaces
-        # prune_checkbox = gr.Checkbox(label="Enable Pruning", value=False, info="Removes redundant weights.", visible=True)
-        # prune_slider = gr.Slider(minimum=0, maximum=90, value=20, step=5, label="Pruning Percentage (%)", visible=True)
-        gr.Markdown("<p style='color:grey;font-size:0.9em;'>Note: Pruning has been disabled to ensure stability on resource-constrained hardware.</p>")
+        ## RE-INTEGRATED: Pruning UI elements are back.
+        gr.Warning("Pruning requires high RAM and may fail for models >2B parameters on free Spaces.")
+        prune_checkbox = gr.Checkbox(label="Enable Pruning (Optional)", value=False, info="Removes redundant weights before quantization.")
+        prune_slider = gr.Slider(minimum=0, maximum=90, value=20, step=5, label="Pruning Percentage (%)", visible=True)
        with gr.Group(visible=False) as onnx_options:
             gr.Markdown("#### ONNX Options")
-            onnx_quant_radio = gr.Radio(["Dynamic"], label="Quantization Type", value="Dynamic", info="Static quantization is not supported in this version.") # Simplified
-            # Hiding calibration for now as it adds complexity
-            # calibration_file_upload = gr.File(label="Upload Calibration Data (.txt)", visible=False, file_types=['.txt'])
+            onnx_quant_radio = gr.Radio(["Dynamic"], label="Quantization Type", value="Dynamic", info="Static quantization via UI is not supported.")
+            calibration_file_upload = gr.File(visible=False) # Keep element for function signature, but hide
         with gr.Group(visible=False) as gguf_options:
             gr.Markdown("#### GGUF Options")
             gguf_quant_dropdown = gr.Dropdown(["q4_k_m", "q5_k_m", "q8_0", "f16"], label="Quantization Strategy", value="q4_k_m")
@@ -252,14 +279,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:

     def update_ui_for_pipeline(pipeline_type):
         return {onnx_options: gr.Group(visible=pipeline_type=="ONNX"), gguf_options: gr.Group(visible=pipeline_type=="GGUF")}
-
+
     pipeline_type_radio.change(fn=update_ui_for_pipeline, inputs=pipeline_type_radio, outputs=[onnx_options, gguf_options])
     analyze_button.click(fn=stage_1_analyze_model, inputs=[model_id_input], outputs=[log_output, analysis_report_output, optimization_accordion])
-    # MODIFIED: Removed pruning inputs from the click function
+
+    ## RE-INTEGRATED: Pruning inputs are now passed to the pipeline function.
     run_button.click(fn=run_amop_pipeline,
-                     inputs=[model_id_input, pipeline_type_radio, onnx_quant_radio, gr.State(None), gguf_quant_dropdown], # Using gr.State(None) as placeholder for removed file upload
+                     inputs=[model_id_input, pipeline_type_radio, prune_checkbox, prune_slider, onnx_quant_radio, calibration_file_upload, gguf_quant_dropdown],
                      outputs=[run_button, analyze_button, final_output, log_output, success_box])

 if __name__ == "__main__":
-    # IMPORTANT: Added .queue() for handling long-running jobs
+    # Use .queue() to handle long-running tasks and prevent timeouts
     demo.queue().launch(debug=True)
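
Note: the re-integrated pruning path above uses torch.nn.Linear (in stage_2_prune_model) and tempfile.mkdtemp() (in run_amop_pipeline), while the only new torch-related import in this diff is "import torch.nn.utils.prune as prune", which does not bind the name torch. Unless torch and tempfile are already imported in the unchanged top of app.py (lines 1-8, not shown in this diff), additions along the lines of this minimal sketch would also be needed:

import tempfile  # assumption: not already imported above; scratch dir for the pruned model
import torch     # assumption: not already imported above; stage_2_prune_model checks isinstance(module, torch.nn.Linear)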