Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import gradio as gr | |
| from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor | |
| import gdown | |
| import os | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
def overlay_mask(image, mask, mask_color=(0, 0, 255), alpha=0.3):
    """Tint the pixels of *image* selected by *mask* with *mask_color*.

    Each masked pixel becomes ``(1 - alpha) * pixel + alpha * mask_color``.
    All other pixels are returned unchanged. A whole-image blend against a
    black canvas (``cv2.addWeighted``) would also darken every unmasked
    pixel by ``alpha``.

    Args:
        image: ``(H, W, C)`` image array. It is not modified.
        mask: ``(H, W)`` array. Pixels where ``mask > 0`` are tinted.
        mask_color: Per-channel colour of length ``C``, in the same channel
            order as *image*. For example, ``(0, 0, 255)`` is blue for an RGB
            image and red for a BGR one.
        alpha: Weight of *mask_color* in the blend, normally in ``[0, 1]``.

    Returns:
        A new array with the same shape and dtype as *image*.
    """
    image = np.asarray(image)
    selected = np.asarray(mask) > 0
    overlay_image = image.copy()
    if not selected.any():
        return overlay_image

    # image[selected] has shape (k, C); the colour broadcasts over the k pixels.
    color = np.asarray(mask_color, dtype=np.float64)
    blended = image[selected].astype(np.float64) * (1 - alpha) + color * alpha
    if np.issubdtype(image.dtype, np.integer):
        # Round and saturate to the dtype's range, as cv2.addWeighted does.
        limits = np.iinfo(image.dtype)
        blended = np.clip(np.rint(blended), limits.min, limits.max)
    overlay_image[selected] = blended.astype(image.dtype)
    return overlay_image
# Predict a segmentation mask and return it overlaid on the input image.
def predict_segmentation(image):
    """Predict a segmentation mask for *image* and return it as an overlay.

    Uses the module-level ``model``, ``image_processor`` and ``device``
    globals, which are created in this script's ``__main__`` section.

    Args:
        image: The uploaded image. This is a PIL image when called from
            ``gr.Image(type="pil")``; array input is also accepted. It is
            ``None`` when the form is submitted without an image.

    Returns:
        An ``(H, W, 3)`` array: the RGB input with every pixel whose
        predicted class index is non-zero tinted by ``overlay_mask``.
        Returns ``None`` if no image was given.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading anything.
        return None
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    # Force 3 channels. RGBA or grayscale uploads would break the image
    # processor and overlay_mask's 3-channel colour assignment.
    raw_image = np.array(image.convert("RGB"))
    height, width = raw_image.shape[:2]

    inputs = image_processor(images=raw_image, return_tensors="pt").to(device)
    with torch.no_grad():
        # Logits are (batch=1, num_labels, h, w) at the model's output resolution.
        logits = model(**inputs).logits
        # Upsample the logits (not the labels) back to the input size. Bilinear
        # with align_corners=False matches the transformers Segformer examples.
        upsampled_logits = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
    # Index [0] drops only the batch axis. A bare .squeeze() would also
    # collapse a 1-pixel-high or 1-pixel-wide image and break the shape match.
    predictions = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
    return overlay_mask(raw_image, predictions)
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    ``torch.nn.DataParallel`` stores parameters under ``module.<name>``.
    Stripping the prefix lets such a checkpoint load into the unwrapped
    model. Keys without the prefix are kept as-is, so both wrapped and
    unwrapped checkpoints work.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def download_checkpoint(url, output):
    """Download *url* to *output* with gdown unless the file already exists.

    Raises:
        SystemExit: If *output* still does not exist afterwards. Some gdown
            versions return ``None`` on failure instead of raising.
    """
    if not os.path.exists(output):
        gdown.download(url, output, quiet=False)
    if not os.path.exists(output):
        raise SystemExit(f"Checkpoint download failed: {output} not found")
    return output


def load_model(model_name, checkpoint_path, device, num_labels=2):
    """Build a SegFormer with a *num_labels*-way head and load *checkpoint_path*.

    Args:
        model_name: Hugging Face model id of the pretrained backbone.
        checkpoint_path: Local ``.pth`` file. It is either a dict containing
            ``'model_state_dict'`` or a bare state dict.
        device: Torch device string the returned model is moved to.
        num_labels: Output channels of the replacement classifier head.

    Returns:
        The model in eval mode on *device*.

    Raises:
        SystemExit: With status 1 if the backbone cannot be downloaded, or if
            no checkpoint weights match the model.
    """
    try:
        model = SegformerForSemanticSegmentation.from_pretrained(model_name)
    except EnvironmentError as e:
        print(f"Lỗi khi tải model từ HuggingFace: {e}")
        # Bare exit() would report status 0 (success) for this failure.
        raise SystemExit(1) from e

    # Replace the pretrained classifier with a 1x1 conv producing num_labels
    # maps. Reuse the existing head's input width (768 for segformer-b5).
    head = model.decode_head.classifier
    model.decode_head.classifier = torch.nn.Conv2d(head.in_channels, num_labels, kernel_size=1)

    # weights_only=True refuses arbitrary pickled objects from the downloaded file.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    result = model.load_state_dict(strip_module_prefix(state_dict), strict=False)
    # strict=False alone would silently keep random weights on a key mismatch.
    # Report mismatches, and abort if nothing at all was loaded.
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} model keys missing from checkpoint, "
              f"e.g. {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} unexpected checkpoint keys, "
              f"e.g. {result.unexpected_keys[:5]}")
    if len(result.missing_keys) == len(model.state_dict()):
        raise SystemExit("Checkpoint does not match the model: no weights were loaded.")

    # A single-image app gains nothing from DataParallel, so the bare model is used.
    return model.to(device).eval()


if __name__ == '__main__':
    # Download the fine-tuned checkpoint if it is not already present.
    url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW&confirm=t&uuid=df1eac8a-fdc0-4438-9a29-202168235570"
    output = "Segformer_ISIC2018_epoch_50_model.pth"
    download_checkpoint(url, output)

    # These module-level names are read by predict_segmentation.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
    model = load_model(MODEL_NAME, output, device)
    image_processor = SegformerImageProcessor()

    iface = gr.Interface(
        fn=predict_segmentation,
        inputs=gr.Image(type="pil"),
        outputs="image",
        # Endpoint name without a leading slash (Gradio's default). Clients
        # still address it as "/predict".
        api_name="predict",
        title="Segmentation with Segformer",
        description="Upload an image to generate a segmentation mask.",
    )
    iface.launch(show_error=True)