# learn-2 / app.py
# Author: urnotwen · Commit c0bb6b3 ("Update app.py", verified) · 2.54 kB
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import timm
import io
import sys
import psutil # 記得 import 這個,才能調用後台硬體資料
# --- 1. Model initialisation ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

# Optional Hugging Face access token; the warning below notes that loading may
# fail without it if the repository is gated.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN,如果是 Gated Model 可能會失敗")

# Choose the device up front so `device` is always defined, even when loading
# the model below fails.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `model` stays None if loading fails, so request handlers can detect the
# failure instead of hitting a NameError on an unbound global.
model = None
try:
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        # Executes model code shipped in the Hub repository; only use with
        # repositories you trust.
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()  # inference mode for layers such as dropout / batch-norm
    print(f"✅ 模型載入成功!使用裝置: {device}")
except Exception as e:
    # Deliberately best-effort: keep the app (and its status endpoint) up,
    # but discard any partially initialised model.
    model = None
    print(f"❌ 模型載入失敗: {e}")
# --- 2. Image processing ---

# Preprocessing pipeline: resize to a fixed 1024x1024, convert to a tensor and
# normalize with the ImageNet mean/std. Built once here rather than per request.
_IMAGE_SIZE = (1024, 1024)
_transform_image = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process_image(input_image):
    """Remove the background from a PIL image using the loaded segmentation model.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was submitted.

    Returns:
        A new RGBA PIL image, the same size as the input, whose alpha channel
        is the predicted foreground mask; ``None`` if no image was given.

    Raises:
        gr.Error: If the model was not loaded successfully at startup.
    """
    if input_image is None:
        return None

    # The module-level loader may have failed, leaving `model` unset or None;
    # report that in the UI instead of crashing with NameError/TypeError.
    net = globals().get("model")
    if net is None:
        raise gr.Error("Model is not loaded; check the server logs.")

    # Normalize() above has exactly 3 channel statistics, so grayscale, palette
    # or RGBA inputs must be converted to RGB before transforming.
    rgb_image = input_image.convert("RGB")
    input_tensor = _transform_image(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Use the last element of the model output and map it to [0, 1].
        preds = net(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Float mask -> PIL image, resized back to the original resolution.
    mask = transforms.ToPILImage()(pred).resize(input_image.size)

    # convert() returns a new image, so the caller's input is not modified.
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image
# --- 3. System status helper ---
def get_system_stats():
    """Report host CPU and RAM utilisation as percentages.

    Wired to the hidden "Status" button in the UI, which exposes it as the
    ``status`` API endpoint.

    Returns:
        dict with keys ``"cpu"`` (``psutil.cpu_percent()``) and ``"ram"``
        (``psutil.virtual_memory().percent``).
    """
    cpu_usage = psutil.cpu_percent()
    ram_usage = psutil.virtual_memory().percent
    return dict(cpu=cpu_usage, ram=ram_usage)
# --- 4. UI ---
with gr.Blocks(title="去背服務測試") as app:
    gr.Markdown("## ✂️ 去背服務測試")

    # Hidden trigger whose only purpose is to expose get_system_stats as the
    # "status" API endpoint; the result is rendered in stats_view.
    stats_view = gr.JSON()
    stats_trigger = gr.Button("Status", visible=False)
    stats_trigger.click(
        fn=get_system_stats,
        inputs=None,
        outputs=stats_view,
        api_name="status",
    )

    with gr.Row():
        source_image = gr.Image(type="pil", label="Input")
        # PNG output is required so the alpha channel (transparency) survives.
        result_image = gr.Image(type="pil", label="Output", format="png")

    remove_bg_button = gr.Button("Remove Background")
    remove_bg_button.click(
        fn=process_image,
        inputs=source_image,
        outputs=result_image,
    )
# --- 5. Launch ---
if __name__ == "__main__":
    # NOTE(review): the original author notes that recent Gradio versions
    # enable CORS on the API by default, so no cors_allowed_origins argument
    # is passed; confirm against the installed Gradio version.
    # Listen on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    app.launch(**launch_options)