import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import timm  # kept explicitly: the RMBG-2.0 remote model code expects timm to be installed
# --- 1. Model initialization ---
model_id = "briaai/RMBG-2.0"
print(f"Loading model: {model_id} ...")

hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ Warning: HF_TOKEN not found; loading may fail if this is a gated model")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()
    print(f"✅ Model loaded successfully! Using device: {device}")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
# --- 2. Image processing ---
def process_image(input_image):
    if input_image is None:
        return None

    # Preprocess: resize, convert to tensor, normalize with ImageNet statistics
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(input_image).unsqueeze(0).to(device)

    # Inference: take the last prediction map and squash it to [0, 1]
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Postprocess: resize the mask to the original size and use it as the alpha channel
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image
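
# Optional sketch (not part of the Gradio app): call process_image directly on a local
# file for quick testing. "sample.jpg" and "output.png" are hypothetical paths.
def run_on_file(path="sample.jpg", out_path="output.png"):
    img = Image.open(path).convert("RGB")
    result = process_image(img)
    if result is not None:
        result.save(out_path)  # PNG keeps the alpha channel added by putalpha
    return result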
# --- 3. Interface ---
with gr.Blocks(title="RMBG-2.0 AI") as app:
    gr.Markdown("## ✂️ Background Removal Test")
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input")
        output_img = gr.Image(type="pil", label="Output", format="png")  # PNG output preserves transparency
    btn = gr.Button("Remove Background")
    btn.click(fn=process_image, inputs=input_img, outputs=output_img)
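
# Optional sketch (an assumption about deployment needs, left commented out): enable
# Gradio's request queue so long-running inferences don't block concurrent users.
# app.queue(max_size=16)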
# --- 4. Launch ---
if __name__ == "__main__":
    # Recent Gradio versions allow CORS on the API by default, so cors_allowed_origins is not needed
    app.launch(server_name="0.0.0.0", server_port=7860)