"""Gradio app for 4x image super-resolution with Swin2SR.

Two checkpoints are offered: the public pretrained Swin2SR real-world x4
model ("Clear") and a locally fine-tuned variant ("Smooth").
"""

import logging

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

logger = logging.getLogger(__name__)

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Model IDs / paths ----
PRETRAINED_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"  # crystal clear
FINETUNED_ID = "swin2sr_div2k_finetuned_x4_1000steps"  # smooth (local folder in repo)

# ---- Style labels ----
# Shared by the dropdown, the intro text and the inference function so the
# three can never drift apart.
MODE_CLEAR = "Clear (OFtB)"
MODE_SMOOTH = "Smooth (Over Smoothing | tuned_1)"

# ---- Load processors ----
processor_pre = AutoImageProcessor.from_pretrained(PRETRAINED_ID)
processor_ft = AutoImageProcessor.from_pretrained(FINETUNED_ID, local_files_only=True)

# ---- Load models ----
model_pre = Swin2SRForImageSuperResolution.from_pretrained(PRETRAINED_ID).to(device)
model_ft = Swin2SRForImageSuperResolution.from_pretrained(
    FINETUNED_ID, local_files_only=True
).to(device)
model_pre.eval()
model_ft.eval()

ABOUT_TEXT = """
**About this tool**

This is a free image enhancement tool that uses Swin2SR to reconstruct and refine image detail at the pixel level (“pixel refabrication”). I built it because many of my own photos are distorted or only exist as low‑quality copies on social media.

**Tip for social media images**

If your only copy is on Instagram, open Instagram on a computer, view the photo in full‑screen, take a screenshot, and crop it tightly around the image before uploading here.

**Usage notes**

Processing and queue times can occasionally exceed ten minutes, especially during heavy use. If this tool gets regular traffic, I plan to upgrade to a stronger GPU backend.

Please reach out if you need any help or have questions.
"""


def swin2sr_upscale(input_image: Image.Image, mode: str):
    """Super-resolve ``input_image`` with the Swin2SR model chosen by ``mode``.

    Args:
        input_image: The uploaded PIL image, or None when nothing was uploaded.
        mode: Style label from the dropdown. ``MODE_SMOOTH`` selects the
            fine-tuned model; any other value selects the pretrained model.

    Returns:
        A ``(sr_image, message)`` tuple: the upscaled RGB PIL image (None on
        failure or missing input) and a Markdown string with resolution info
        or an error description.
    """
    if input_image is None:
        return None, "No image uploaded."

    try:
        # Drop alpha / palette modes so the processor always sees 3 channels.
        input_image = input_image.convert("RGB")
        w_lr, h_lr = input_image.size

        if mode == MODE_SMOOTH:
            model, processor = model_ft, processor_ft
        else:
            model, processor = model_pre, processor_pre

        inputs = processor(images=input_image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # The Swin2SR image processor pads the input on the bottom/right up to
        # a multiple of its pad size, so the reconstruction is slightly larger
        # than (h*scale, w*scale) and ends in a mirrored strip. Crop it back to
        # the true upscaled extent; slicing is a no-op if nothing was padded.
        scale = getattr(model.config, "upscale", 4)
        sr_tensor = outputs.reconstruction[0, :, : h_lr * scale, : w_lr * scale]

        # Round (rather than truncate) when quantizing to 8-bit; plain
        # ``.byte()`` truncates and biases every pixel downward.
        sr_array = (
            sr_tensor.clamp(0, 1)
            .mul(255)
            .round()
            .byte()
            .cpu()
            .permute(1, 2, 0)
            .numpy()
        )
        # Keep an RGB mode so the JPEG-formatted output component can encode it.
        sr_image = Image.fromarray(sr_array).convert("RGB")

        w_sr, h_sr = sr_image.size
        # Blank line between entries: gr.Markdown does not treat a single
        # newline as a line break.
        msg = (
            f"Original resolution: {w_lr}×{h_lr} pixels\n\n"
            f"Super-resolved: {w_sr}×{h_sr} pixels "
            f"(scale factors: {w_sr / w_lr:.1f}×, {h_sr / h_lr:.1f}×)"
        )
        return sr_image, msg
    except Exception as e:  # UI boundary: report the failure instead of crashing
        logger.exception("Super-resolution failed")
        return None, f"Error: {type(e).__name__}: {e}"


# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("# Image Super-Resolution (Swin2SR x4)")
    gr.Markdown(
        f"Choose **{MODE_CLEAR}** for the original pretrained Swin2SR model, "
        f"or **{MODE_SMOOTH}** for the Swin2SR version we fine-tuned on DIV2K patches."
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload low-res image")
        output_image = gr.Image(type="pil", format="jpeg", label="4x Super-resolved image")

    res_info = gr.Markdown("Resolution info.")

    # ---- info / help box ----
    gr.Markdown(ABOUT_TEXT)

    mode_dropdown = gr.Dropdown(
        label="Style",
        choices=[MODE_CLEAR, MODE_SMOOTH],
        value=MODE_CLEAR,
        interactive=True,
    )

    run_btn = gr.Button("Upscale")
    run_btn.click(
        fn=swin2sr_upscale,
        inputs=[input_image, mode_dropdown],
        outputs=[output_image, res_info],
    )

if __name__ == "__main__":
    demo.launch()