# learn-2 / apptest.py
import os
import gradio as gr
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import timm
# --- 🔍 Version check section (see here) ---
import sys
print("="*30)
print(f"Python version: {sys.version}")
print(f"Gradio version: {gr.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Timm version: {timm.__version__}")
print("="*30)
# -----------------------------
# --- 1. Initialize the model ---
model_id = "briaai/RMBG-2.0"
print(f"Loading model: {model_id} ...")
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ Warning: HF_TOKEN not detected")
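# Note (assumption about deployment, not in the original script): on Hugging Face Spaces,
# HF_TOKEN is usually added as a repository secret and exposed to the app as an
# environment variable, which is what os.getenv("HF_TOKEN") reads here.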
try:
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token
    )
    device = torch.device("cpu")
    model.to(device)
    model.eval()
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
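# A minimal sketch (not part of the original flow): if this environment had a GPU,
# the model could be moved there instead of the CPU, e.g.
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model.to(device)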
# --- 2. Image processing (official logic) ---
def process_image(input_image):
    if input_image is None:
        return None
    image_size = (1024, 1024)
    # Resize and normalize with ImageNet mean/std before inference
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_images = transform_image(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Turn the predicted mask back into a PIL image, resize it to the original
    # resolution, and use it as the alpha channel of the input
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)
    image = input_image.convert("RGBA")
    image.putalpha(mask)
    return image
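# Minimal usage sketch outside the Gradio UI (assumes a local file "input.jpg" exists;
# the file names here are illustrative, not from the original script):
#   img = Image.open("input.jpg").convert("RGB")
#   cutout = process_image(img)
#   cutout.save("output.png")  # PNG keeps the alpha channel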
# --- 3. Interface ---
# For verification, also show the versions on the web page
version_info = f"Running versions - Gradio: {gr.__version__} | Torch: {torch.__version__}"
with gr.Blocks(title="Version check") as app:
    gr.Markdown("## ✂️ AI automatic background removal")
    gr.Markdown(f"ℹ️ **{version_info}**")  # shown directly on the page
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload image")
            btn = gr.Button("Remove background")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Result")
    btn.click(fn=process_image, inputs=input_img, outputs=output_img)
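# Note (assumption, not part of the original): when running locally, app.launch(share=True)
# creates a temporary public link; on Hugging Face Spaces the plain app.launch() below is enough.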
if __name__ == "__main__":
    app.launch()