File size: 2,442 Bytes
a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d a2e0537 438cd9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
# --- 1. 初始化模型 ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")
hf_token = os.getenv("HF_TOKEN")
# 檢查 Token
if not hf_token:
print("⚠️ 警告: 未偵測到 HF_TOKEN,請檢查 Settings 中的 Secret。")
try:
# 載入模型
model = AutoModelForImageSegmentation.from_pretrained(
model_id,
trust_remote_code=True,
token=hf_token
)
# 強制使用 CPU (避免免費空間報錯)
device = torch.device("cpu")
model.to(device)
model.eval()
print("✅ 模型載入成功!")
except Exception as e:
print(f"❌ 模型載入失敗: {e}")
# --- 2. 定義圖像處理邏輯 (官方邏輯) ---
def process_image(input_image):
if input_image is None:
return None
# 準備圖像尺寸
image_size = (1024, 1024)
# 定義轉換
transform_image = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 轉換並增加維度
input_images = transform_image(input_image).unsqueeze(0).to(device)
# 開始預測
with torch.no_grad():
preds = model(input_images)[-1].sigmoid().cpu()
# 處理預測結果
pred = preds[0].squeeze()
# 轉回 PIL 圖片
pred_pil = transforms.ToPILImage()(pred)
# 調整回原始圖片的大小
mask = pred_pil.resize(input_image.size)
# 合成去背圖
image = input_image.convert("RGBA")
image.putalpha(mask)
return image
# --- 3. 建立介面 (移除導致錯誤的 head 參數) ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
gr.Markdown("## ✂️ AI 自動去背 (RMBG 2.0)")
gr.Markdown("由於版本相容性,目前以標準網頁模式運行。")
with gr.Column():
input_img = gr.Image(type="pil", label="上傳圖片", sources=["upload", "clipboard", "webcam"])
btn = gr.Button("✨ 開始去背", variant="primary", size="lg")
output_img = gr.Image(type="pil", label="去背結果", format="png", show_download_button=True)
btn.click(fn=process_image, inputs=input_img, outputs=output_img)
if __name__ == "__main__":
app.launch() |