|
|
import gradio as gr |
|
|
from transformers import AutoModelForImageSegmentation |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import io |
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint used for background removal.
model_id = "briaai/RMBG-2.0"

print(f"正在載入模型: {model_id} ...")

# Inference runs on CPU. Defined outside the try-block so the name exists
# even when model loading fails (previously it was only set on success).
device = torch.device("cpu")

# Stays None when loading fails; process_image checks for this so the UI
# can report a clear error instead of crashing with a NameError.
model = None

try:
    # trust_remote_code lets transformers execute the modeling code that is
    # published alongside this checkpoint on the Hub.
    model = AutoModelForImageSegmentation.from_pretrained(model_id, trust_remote_code=True)
    model.to(device)
    # Switch to inference mode.
    model.eval()
    print("模型載入成功!")
except Exception as e:
    # Best-effort load: keep the app starting so the problem is visible,
    # but make sure a half-initialised model is never used.
    model = None
    print(f"模型載入失敗: {e}")
|
|
|
|
|
|
|
|
def process_image(input_image):
    """Remove the background of ``input_image`` with the loaded RMBG model.

    Args:
        input_image: PIL image delivered by the ``gr.Image(type="pil")``
            input component, or ``None`` when nothing was uploaded.

    Returns:
        ``None`` if no image was given. Otherwise, an RGBA copy of
        ``input_image`` at its original size whose alpha channel is the
        predicted foreground mask.

    Raises:
        gr.Error: if the segmentation model failed to load at startup.
    """
    if input_image is None:
        return None

    if model is None:
        raise gr.Error("模型尚未成功載入,請檢查伺服器日誌。")

    orig_w, orig_h = input_image.size

    # Normalize below uses 3-channel mean/std, so RGBA / greyscale /
    # palette uploads must be converted to RGB first or the transform fails.
    rgb_image = input_image.convert("RGB")

    # Resize to the fixed 1024x1024 network input and apply ImageNet
    # mean/std normalisation.
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform_image(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
        # Per the RMBG-2.0 model card, the forward pass returns a list whose
        # LAST element is the final (batch, 1, H, W) logit map. The old
        # `[0][0]` dropped the batch dim, and bilinear interpolation rejects
        # 3-D input, so keep the tensor 4-D here.
        logits = outputs[-1] if isinstance(outputs, (list, tuple)) else outputs
        logits = torch.nn.functional.interpolate(
            logits, size=(orig_h, orig_w), mode="bilinear", align_corners=False
        )
        probs = torch.sigmoid(logits)

    # Index explicitly rather than squeeze() so 1-pixel-wide/high images
    # still yield a 2-D (orig_h, orig_w) mask.
    mask = probs[0, 0].cpu().numpy()

    # A 2-D uint8 array maps to an 8-bit greyscale ("L") image; the `mode`
    # argument is deprecated in recent Pillow releases, so it is omitted.
    mask_img = Image.fromarray((mask * 255).round().astype("uint8"))

    output_img = input_image.convert("RGBA")
    output_img.putalpha(mask_img)
    return output_img
|
|
|
|
|
|
|
|
|
|
|
# Extra markup passed to gr.Blocks(head=...). Gradio inserts this text inside
# the page's existing <head>, so it must be a fragment of head-level tags:
# wrapping it in its own <head>...</head> is invalid HTML, and the stray
# </head> can close the real document head prematurely.
# The meta tags tune the page for mobile / home-screen ("web app") use; the
# style block hides Gradio's footer and stretches the container to full height.
pwa_header = """
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<meta name="apple-mobile-web-app-capable" content="yes">
<meta name="apple-mobile-web-app-status-bar-style" content="black-translucent">
<meta name="theme-color" content="#0b0f19">
<title>AI 去背神器</title>
<style>
/* 隱藏 Gradio 預設的頁尾,讓畫面更乾淨 */
footer {display: none !important;}
.gradio-container {min-height: 100vh !important;}
</style>
"""
|
|
|
|
|
|
|
|
# Single-page UI: a heading, an input column (image + trigger button) and an
# output column that shows the cut-out result.
with gr.Blocks(theme=gr.themes.Soft(), head=pwa_header) as app:
    gr.Markdown(
        """
        # ✂️ AI 自動去背 (RMBG 2.0)

        上傳照片,自動去除背景。
        """
    )

    with gr.Row():
        # Left: the source image and the button that starts processing.
        with gr.Column():
            input_img = gr.Image(
                type="pil",
                label="點擊上傳或拍照",
                sources=["upload", "clipboard"],
            )
            btn = gr.Button("開始去背", variant="primary", size="lg")

        # Right: the result, served as PNG so the alpha channel is kept.
        with gr.Column():
            output_img = gr.Image(
                type="pil",
                label="去背結果 (長按儲存)",
                format="png",
                show_download_button=True,
            )

    # Background removal runs only when the button is pressed.
    btn.click(fn=process_image, inputs=[input_img], outputs=[output_img])
|
|
|
|
|
|
|
|
def main():
    """Start the Gradio server for ``app`` with default launch settings."""
    app.launch()


if __name__ == "__main__":
    main()