irebmann's picture
wrapper function to debug
3b5f9ac verified
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
# ---- Device ----
# Prefer CUDA when torch reports it as available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Model IDs / paths ----
PRETRAINED_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"  # crystal clear
FINETUNED_ID = "swin2sr_div2k_finetuned_x4_1000steps"  # smooth (local folder in repo)

# ---- Pretrained checkpoint: processor + model ----
processor_pre = AutoImageProcessor.from_pretrained(PRETRAINED_ID)
model_pre = (
    Swin2SRForImageSuperResolution.from_pretrained(PRETRAINED_ID)
    .to(device)
    .eval()
)

# ---- Fine-tuned checkpoint: processor + model ----
# local_files_only=True restricts loading to files already on disk.
processor_ft = AutoImageProcessor.from_pretrained(
    FINETUNED_ID, local_files_only=True
)
model_ft = (
    Swin2SRForImageSuperResolution.from_pretrained(
        FINETUNED_ID, local_files_only=True
    )
    .to(device)
    .eval()
)
# Markdown help text for the page; it is rendered in the UI through
# gr.Markdown(ABOUT_TEXT). This is user-facing copy, so edit the wording with care.
ABOUT_TEXT = """
**About this tool**
This is a free image enhancement tool that uses Swin2SR to reconstruct and refine image detail at the pixel level (“pixel refabrication”).
I built it because many of my own photos are distorted or only exist as low‑quality copies on social media.
**Tip for social media images**
If your only copy is on Instagram, open Instagram on a computer, view the photo in full‑screen, take a screenshot, and crop it tightly around the image before uploading here.
**Usage notes**
Processing and queue times can occasionally exceed ten minutes, especially during heavy use.
If this tool gets regular traffic, I plan to upgrade to a stronger GPU backend.
Please reach out if you need any help or have questions.
"""
def swin2sr_upscale(input_image: Image.Image, mode: str):
    """Upscale an image with one of the two loaded Swin2SR checkpoints.

    Args:
        input_image: PIL image from the upload widget, or ``None`` when
            nothing has been uploaded.
        mode: Style label chosen in the UI. The exact string
            ``"Smooth (Over Smoothing | tuned_1)"`` selects the fine-tuned
            checkpoint; every other value selects the pretrained one.

    Returns:
        A ``(image, message)`` tuple. On success ``image`` is the
        super-resolved RGB PIL image and ``message`` reports the original and
        output resolutions. When no image is given, or when anything fails,
        ``image`` is ``None`` and ``message`` explains why.
    """
    if input_image is None:
        return None, "No image uploaded."

    try:
        input_image = input_image.convert("RGB")
        w_lr, h_lr = input_image.size

        if mode == "Smooth (Over Smoothing | tuned_1)":
            model, processor = model_ft, processor_ft
        else:
            model, processor = model_pre, processor_pre

        inputs = processor(images=input_image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # Index the batch axis explicitly; a bare squeeze() would drop *every*
        # size-1 axis, not only the batch one.
        sr_tensor = outputs.reconstruction[0]

        # The Swin2SR image processor pads the input on the bottom/right so its
        # sides fit the model's window size. Crop the upscaled padding away so
        # the result is exactly `upscale` times the uploaded image.
        upscale = getattr(model.config, "upscale", None)
        if upscale:
            sr_tensor = sr_tensor[:, : h_lr * upscale, : w_lr * upscale]

        # Round before the uint8 cast; a plain cast truncates and biases
        # every pixel slightly darker.
        sr_array = (
            sr_tensor.clamp(0, 1)
            .mul(255)
            .round()
            .byte()
            .cpu()
            .permute(1, 2, 0)  # CHW -> HWC for PIL
            .numpy()
        )
        sr_image = Image.fromarray(sr_array).convert("RGB")

        w_sr, h_sr = sr_image.size
        msg = (
            f"Original resolution: {w_lr}×{h_lr} pixels\n"
            f"Super-resolved: {w_sr}×{h_sr} pixels "
            f"(scale factors: {w_sr / w_lr:.1f}×, {h_sr / h_lr:.1f}×)"
        )
        return sr_image, msg
    except Exception as e:
        # UI boundary: show the error to the user instead of crashing the
        # request, but keep the full traceback in the server log.
        import traceback

        traceback.print_exc()
        return None, f"Error: {type(e).__name__}: {e}"
# NOTE: The triple-quoted string below is never assigned or referenced, so it
# is an inert expression. It holds an earlier draft of swin2sr_upscale, kept
# only for reference. That draft compares `mode` against "Smooth (fine-tuned)"
# and returns a bare None (not a pair) when no image is given.
# TODO(review): consider deleting this and relying on version control history.
'''
# ---- Inference function ----
def swin2sr_upscale(input_image: Image.Image, mode: str):
"""
Run 4x super-resolution using Swin2SR.
mode: "Crystal clear (pretrained)" or "Smooth (fine-tuned)".
"""
if input_image is None:
return None
# original size
w_lr, h_lr = input_image.size
if mode == "Smooth (fine-tuned)":
model = model_ft
processor = processor_ft
else:
model = model_pre
processor = processor_pre
inputs = processor(images=input_image, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
sr_tensor = outputs.reconstruction.squeeze().clamp(0, 1)
sr_array = (sr_tensor.mul(255).byte().cpu().permute(1, 2, 0).numpy())
sr_image = Image.fromarray(sr_array)
# ---ensure JPEG-compatible mode ---
sr_image = sr_image.convert("RGB")
# ----OG and new size ---
w_sr, h_sr = sr_image.size
msg = (
f"Original resolution: {w_lr}×{h_lr} pixels\n"
f"Super-resolved: {w_sr}×{h_sr} pixels "
f"(scale factors: {w_sr / w_lr:.1f}×, {h_sr / h_lr:.1f}×)"
)
return sr_image, msg
'''
# ---- Gradio UI ----
# Style labels shown in the dropdown. The "smooth" label must stay identical to
# the string swin2sr_upscale compares `mode` against, otherwise the fine-tuned
# model is never selected.
_MODE_CLEAR = "Clear (OFtB)"
_MODE_SMOOTH = "Smooth (Over Smoothing | tuned_1)"

with gr.Blocks() as demo:
    gr.Markdown("# Image Super-Resolution (Swin2SR x4)")
    # Name the options exactly as they appear in the dropdown so the intro
    # text and the control agree.
    gr.Markdown(
        f"Choose **{_MODE_CLEAR}** for the original pretrained Swin2SR model, "
        f"or **{_MODE_SMOOTH}** for the Swin2SR version we fine-tuned on DIV2K patches."
    )

    # Input and result side by side.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload low-res image")
        output_image = gr.Image(
            type="pil", format="jpeg", label="4x Super-resolved image"
        )

    # Filled with the resolution summary (or an error message) after each run.
    res_info = gr.Markdown("Resolution info.")

    # ---- info / help box ----
    gr.Markdown(ABOUT_TEXT)

    mode_dropdown = gr.Dropdown(
        label="Style",
        choices=[_MODE_CLEAR, _MODE_SMOOTH],
        value=_MODE_CLEAR,
        interactive=True,
    )
    run_btn = gr.Button("Upscale")

    # swin2sr_upscale returns (image, message), matching the two outputs.
    run_btn.click(
        fn=swin2sr_upscale,
        inputs=[input_image, mode_dropdown],
        outputs=[output_image, res_info],
    )

if __name__ == "__main__":
    demo.launch()