File size: 2,401 Bytes
5ff4dd4
 
 
 
 
 
 
 
299c39e
 
 
 
 
 
 
 
 
 
5ff4dd4
 
 
 
 
299c39e
5ff4dd4
 
 
 
 
299c39e
 
 
5ff4dd4
299c39e
5ff4dd4
 
 
 
 
 
299c39e
5ff4dd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299c39e
 
 
5ff4dd4
299c39e
5ff4dd4
299c39e
 
 
 
 
 
 
 
 
5ff4dd4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import timm

# --- 🔍 Version check: log the runtime library versions at startup ---
import sys


def _print_versions():
    """Print the Python, Gradio, Torch and timm versions between two rule lines."""
    rule = "=" * 30
    print(rule)
    for label, version in (
        ("Python", sys.version),
        ("Gradio", gr.__version__),
        ("Torch", torch.__version__),
        ("Timm", timm.__version__),
    ):
        print(f"{label} version: {version}")
    print(rule)


_print_versions()
# -----------------------------

# --- 1. Model initialisation ---
model_id = "briaai/RMBG-2.0"
print(f"正在載入模型: {model_id} ...")

# Optional Hugging Face access token; passed to from_pretrained below (may be None).
hf_token = os.getenv("HF_TOKEN")

if not hf_token:
    print("⚠️ 警告: 未偵測到 HF_TOKEN")

# Bind both names before attempting the load so they always exist, even when
# loading fails. Downstream code can then test `model is None` instead of
# crashing with a NameError on the first request.
device = torch.device("cpu")
model = None

try:
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()  # inference mode (disables dropout / batch-norm updates)
    print("✅ 模型載入成功!")
except Exception as e:
    # Top-level boundary: report the failure but keep the UI running.
    # Reset `model` in case the failure happened after from_pretrained returned
    # (e.g. in .to()), so a half-initialised model is never used.
    model = None
    print(f"❌ 模型載入失敗: {e}")

# --- 2. Image processing (follows the official RMBG-2.0 inference example) ---
# Resolution the image is resized to before being fed to the model.
IMAGE_SIZE = (1024, 1024)

# Pre-processing pipeline, built once at import time instead of on every request.
TRANSFORM_IMAGE = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    # Per-channel mean/std normalisation (3 channels), as in the official example.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process_image(input_image):
    """Remove the background of *input_image* using the segmentation model.

    Args:
        input_image: PIL image from the Gradio input component, or None when
            nothing has been uploaded.

    Returns:
        An RGBA copy of *input_image* whose alpha channel is the predicted
        foreground mask (resized back to the original image size), or None
        when *input_image* is None.

    Raises:
        gr.Error: If the segmentation model is not loaded (``model`` is None).
    """
    if input_image is None:
        return None
    if model is None:
        # Surface a readable message in the UI instead of an opaque crash.
        raise gr.Error("模型載入失敗,無法處理圖片")

    # Normalize() above expects exactly 3 channels; RGBA, grayscale or palette
    # uploads would otherwise raise, so the model always receives an RGB copy.
    rgb_image = input_image.convert("RGB")
    input_batch = TRANSFORM_IMAGE(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Take the last model output, as the official example does, and map
        # it to [0, 1] with a sigmoid.
        preds = model(input_batch)[-1].sigmoid().cpu()

    # First (only) item of the batch, with singleton dimensions removed.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(input_image.size)

    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image

# --- 3. UI ---
# For verification, the running versions are also shown on the web page.
version_info = f"目前運行版本 - Gradio: {gr.__version__} | Torch: {torch.__version__}"

with gr.Blocks(title="版本檢查") as app:

    # Plain string: this heading has no placeholders, so no f-string is needed.
    gr.Markdown("## ✂️ AI 自動去背")
    gr.Markdown(f"ℹ️ **{version_info}**")  # rendered directly on the page

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="上傳圖片")
            btn = gr.Button("開始去背")
        with gr.Column():
            output_img = gr.Image(type="pil", label="去背結果")

    # Run background removal on the uploaded image when the button is clicked.
    btn.click(fn=process_image, inputs=input_img, outputs=output_img)

def _main():
    """Start the Gradio server for the app defined above."""
    app.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    _main()